# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging

import paddle

from paddle.optimizer import Optimizer
from paddle.distributed.utils.log_utils import get_logger
from paddle.fluid.framework import in_dygraph_mode

# Old version
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

# New version
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import GroupShardedStage3
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import GroupShardedScaler

logger_ = get_logger(logging.WARNING)


def group_sharded_parallel(model,
                           optimizer,
                           level,
                           scaler=None,
                           group=None,
                           offload=False,
                           sync_buffers=False,
                           buffer_max_size=2**23,
                           segment_size=2**20,
                           sync_comm=False):
    """
    Use group_sharded_parallel to apply a group sharded configuration to the given model, optimizer and GradScaler. `level` has three string options, 'os', 'os_g' and 'p_g_os', which correspond to three different usage scenarios: optimizer state sharding, optimizer state + gradient sharding, and parameter + gradient + optimizer state sharding.
    Since optimizer state + gradient sharding is essentially a further optimization of optimizer state sharding, the 'os' level is currently implemented on top of 'os_g'.

    Args:
        model (Layer): The layer to be wrapped with group_sharded_parallel.
        optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
        level (str): The level of group sharded. One of `os`, `os_g` or `p_g_os`.
        scaler (GradScaler, optional): The GradScaler to be wrapped if AMP is used. Defaults to None, which means no GradScaler is used.
        group (Group, optional): The communication group instance. Defaults to None, which means the default global group is used.
        offload (bool, optional): Whether to enable CPU offload. Defaults to False, which means offload is not used.
        sync_buffers (bool, optional): Whether to broadcast the model buffers across ranks. It is generally needed when the model has registered buffers. Defaults to False, which means buffers are not synchronized.
        buffer_max_size (int, optional): The maximum size of the buffer used to fuse gradients in `os_g`. The larger the buffer, the more GPU memory is used. Defaults to 2**23 elements.
        segment_size (int, optional): The smallest size of a parameter to be sharded in `p_g_os`. Defaults to 2**20 elements.
        sync_comm (bool, optional): Whether to use synchronous communication; only used in `p_g_os`. Defaults to False, which means asynchronous communication is used.

    Returns:
        model: The given model wrapped for group sharded training.
        optimizer: The given optimizer wrapped for group sharded training.
        scaler: The given scaler wrapped for group sharded training.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.fluid.dygraph.nn import Linear
            from paddle.distributed import fleet
            from paddle.distributed.sharding import group_sharded_parallel

            fleet.init(is_collective=True)
            group = paddle.distributed.new_group([0, 1])
            model = Linear(1000, 1000)

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

            # wrap sharding model, optimizer and scaler
            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)

            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True

            out = model(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
    """
    # check option types
    assert isinstance(
        model,
        paddle.nn.Layer), "The model must be the instance of paddle.nn.Layer."
    assert isinstance(
        optimizer, Optimizer
    ), "The optimizer must be the instance of paddle.optimizer.Optimizer."
    assert level in ['os', 'os_g',
                     'p_g_os'], "The level must be os, os_g or p_g_os."

    def check_dtype(param):
        return param.dtype == paddle.float16

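    # Collect any float16 parameters: AMP (fp16) training requires a GradScaler to be passed in.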
    params_fp16 = list(filter(check_dtype, model.parameters()))
    if scaler is None and len(params_fp16) > 0:
        raise ValueError("Please enter the correct scaler.")
    # convert model/optimizer/scaler
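    # Levels 'os' and 'os_g' both map to the Stage2 implementation (optimizer state + gradient sharding).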
    if level in ['os', 'os_g']:
        logger_.info("*" * 30)
        logger_.info("Sharded level os uses sharded level os_g achieved now.")
        logger_.info("*" * 30)
        if in_dygraph_mode():
            optimizer = GroupShardedOptimizerStage2(
                params=optimizer._parameter_list,
                optim=optimizer,
                group=group,
                offload=offload)
            model = GroupShardedStage2(model,
                                       optimizer,
                                       group=group,
                                       sync_buffers=sync_buffers,
                                       buffer_max_size=buffer_max_size)
        else:
            optimizer = ShardingOptimizerStage2(params=model.parameters(),
                                                optim=optimizer,
                                                group=group,
                                                offload=offload)
            model = ShardingStage2(model,
                                   optimizer,
                                   group=group,
                                   sync_buffers=sync_buffers,
                                   buffer_max_size=buffer_max_size)
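    # Level 'p_g_os' additionally shards the parameters, which is handled by the Stage3 implementation.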
    elif level == 'p_g_os':
        if in_dygraph_mode():
            model = GroupShardedStage3(model,
                                       optimizer=optimizer,
                                       group=group,
                                       sync_buffers=sync_buffers,
                                       segment_size=segment_size,
                                       offload=offload,
                                       sync_comm=sync_comm)
        else:
            model = ShardingStage3(model,
                                   optimizer=optimizer,
                                   group=group,
                                   sync_buffers=sync_buffers,
                                   segment_size=segment_size,
                                   offload=offload,
                                   sync_comm=sync_comm)
    else:
        raise ValueError("Please enter the correct level.")
    if isinstance(scaler, paddle.amp.GradScaler):
        if in_dygraph_mode():
            scaler = GroupShardedScaler(scaler)
        else:
            scaler = ShardingScaler(scaler)
    logger_.info("*" * 30)
    logger_.info(
        "If there is a communication hang using group sharded, please check whether the communication operations of each process are unified."
    )
    logger_.info("*" * 30)

    return model, optimizer, scaler


def save_group_sharded_model(model, output, optimizer=None):
    """
    Save the state of a model (and optionally its optimizer) that has been wrapped with group_sharded_parallel.

    Note:
        If the model is saved with save_group_sharded_model, then when loading it again you need to set the model and optimizer state_dict first, before wrapping them with group_sharded_parallel.
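
        A minimal loading sketch (assuming `model` and `optimizer` are freshly constructed and `output_dir` is the directory that was passed to `save_group_sharded_model`):

        .. code-block:: python

            import os
            import paddle

            # restore the plain (unwrapped) model and optimizer state first
            model.set_state_dict(paddle.load(os.path.join(output_dir, "model.pdmodel")))
            optimizer.set_state_dict(paddle.load(os.path.join(output_dir, "model.pdopt")))
            # then wrap them with group_sharded_parallel as shown in the example below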

    Args:
        model (Layer): The model wrapped with group_sharded_parallel.
        output (str): Save directory.
        optimizer (Optimizer, optional): The optimizer wrapped with group_sharded_parallel. Defaults to None, which means the optimizer state is not saved.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.fluid.dygraph.nn import Linear
            from paddle.distributed import fleet
            from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

            fleet.init(is_collective=True)
            group = paddle.distributed.new_group([0, 1])
            model = Linear(1000, 1000)

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

            # wrap sharding model, optimizer and scaler
            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)

            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True

            out = model(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            # save model and optimizer state_dict
            save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
    """
    logger_.info(
        "==========Begin to save group sharded model and optimizer==========")
    assert not os.path.isfile(
        output
    ), "Saving directory ({}) should be a directory, not a file".format(output)
    os.makedirs(output, exist_ok=True)
    output_model = os.path.join(output, "model.pdmodel")
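
    # Stage2 keeps a full copy of the parameters on every rank, so the inner
    # layer's state_dict can be saved directly.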
    if isinstance(model, (ShardingStage2, GroupShardedStage2)):
        paddle.save(model._layer.state_dict(), output_model)
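    # Stage3 shards the parameters themselves, so gather the full parameters
    # (kept on CPU when offload is enabled) before saving.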
    elif isinstance(model, (ShardingStage3, GroupShardedStage3)):
        convert2cpu = model._offload
        model.get_all_parameters(convert2cpu=convert2cpu)
        paddle.save(model._layer.state_dict(), output_model)
    else:
        raise ValueError(
            "Please use the layer which is wrapped with group_sharded_parallel."
        )

    if optimizer is not None:
        assert hasattr(
            optimizer, "_optim"
        ), "Please use the optimizer which is wrapped with group_sharded_parallel."
        output_opt = os.path.join(output, "model.pdopt")
        paddle.save(optimizer._optim.state_dict(), output_opt)
    logger_.info(
        "==========End to save group sharded model and optimizer==========")