# dygraph_sharding_stage3.py (7.9 KB), committed by Baibaifan.
# -*- coding: UTF-8 -*-

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn

from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

# Training hyper-parameters shared by every train_mlp run in this script.
epoch = 10
batch_size = 32
# Seed both Paddle's and NumPy's global RNGs; NumPy also produces the
# synthetic samples in reader_decorator.
paddle.seed(2021)
np.random.seed(2021)
# NOTE(review): base_lr, momentum_rate and l2_decay are not referenced anywhere
# in this file (optimizer_setting hard-codes its own values) - candidates for
# removal once confirmed nothing imports them.
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
# Initialise fleet in collective mode at import time, before train_mlp creates
# any communication group.
fleet.init(is_collective=True)


class MLP(fluid.Layer):
    """Three stacked Linear layers (no activations) used as the test model.

    Args:
        linear_size (int): input width and hidden width; the final layer
            projects to 10 outputs.
        param_attr, bias_attr: accepted for signature compatibility but not
            used - every Linear is built with its default attributes.
    """

    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
        super().__init__()
        width = linear_size
        # Assignment order registers the sublayers and so fixes the order in
        # which their parameters are listed; keep 1 -> 2 -> 3.
        self._linear1 = Linear(width, width)
        self._linear2 = Linear(width, width)
        self._linear3 = Linear(width, 10)

    def forward(self, inputs):
        """Run ``inputs`` through the three Linear layers in sequence."""
        hidden = inputs
        for layer in (self._linear1, self._linear2, self._linear3):
            hidden = layer(hidden)
        return hidden


def reader_decorator(linear_size=1000):
    """Return a no-argument sample generator for ``paddle.batch``.

    Each call of the returned function yields 100 ``(img, label)`` pairs:
    ``img`` is a float32 vector of ``linear_size`` values drawn from NumPy's
    global RNG (``np.random.rand``), and ``label`` is the int64 array ``[1]``.
    """

    def _generate_samples():
        for _ in range(100):
            yield (np.random.rand(linear_size).astype('float32'),
                   np.ones(1).astype('int64'))

    return _generate_samples


def optimizer_setting(model, use_pure_fp16, opt_group=False):
    """Build the AdamW optimizer used by every training run.

    Args:
        model: layer whose ``parameters()`` are optimized.
        use_pure_fp16 (bool): forwarded as AdamW's ``multi_precision``.
        opt_group (bool): when True, the parameters are passed as a single
            parameter-group dict ``[{"params": ...}]`` instead of a flat list.

    Returns:
        ``paddle.optimizer.AdamW`` with lr 0.001, weight decay 1e-5 and
        global-norm gradient clipping at 1.0.
    """
    params = model.parameters()
    if opt_group:
        params = [{"params": params}]
    return paddle.optimizer.AdamW(
        parameters=params,
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
        multi_precision=use_pure_fp16)


def train_mlp(model,
              sharding_stage,
              use_pure_fp16=False,
              accumulate_grad=False,
              opt_group=False,
              recompute=False):
    """Train ``model`` under the requested sharding stage; return its parameters.

    Args:
        model: the MLP to train. It is wrapped (and, for fp16, AMP-decorated)
            inside this function.
        sharding_stage (int): 2 wraps the optimizer in ShardingOptimizerStage2
            and the model in ShardingStage2; 3 wraps the model in
            ShardingStage3. Any other value trains the unwrapped model.
        use_pure_fp16 (bool): decorate the model with AMP level 'O2' and drive
            backward/step through a ShardingScaler around a GradScaler.
        accumulate_grad (bool): forwarded to ShardingStage2 as
            ``accumulate_grads``; the loop then steps once per epoch instead
            of once per batch (see the NOTE inside the loop).
        opt_group (bool): hand the parameters to the optimizer as a single
            parameter-group dict instead of a flat list.
        recompute (bool): forwarded to ShardingStage3 as ``sync_comm``.

    Returns:
        ``model.parameters()`` of the (possibly wrapped) model after training.
        For stage 3, ``get_all_parameters()`` is called first (presumably to
        materialise the sharded parameters on this rank - confirm).
    """
    group = paddle.distributed.new_group([0, 1])
    optimizer = optimizer_setting(
        model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)

    if use_pure_fp16:
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')
        scaler = ShardingScaler(
            paddle.amp.GradScaler(init_loss_scaling=32768))

    if sharding_stage == 2:
        optimizer = ShardingOptimizerStage2(
            params=model.parameters(), optim=optimizer, group=group)
        model = ShardingStage2(
            model,
            optimizer,
            group=group,
            buffer_max_size=2**21,
            accumulate_grads=accumulate_grad)
    elif sharding_stage == 3:
        model = ShardingStage3(
            model, optimizer=optimizer, group=group, sync_comm=recompute)

    def backward_and_step(loss):
        # One optimizer update from ``loss``: plain backward/step in fp32,
        # loss scaling plus scaler-driven step/update in fp16. Gradients are
        # cleared afterwards in both modes. ``scaler`` is only read when
        # use_pure_fp16 is True, which is exactly when it was assigned.
        if use_pure_fp16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        optimizer.clear_grad()

    train_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True)
    train_loader = paddle.io.DataLoader.from_generator(
        capacity=32,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True)
    train_loader.set_sample_list_generator(train_reader)

    for _ in range(epoch):
        model.train()
        for img, label in train_loader():
            label.stop_gradient = True
            img.stop_gradient = True
            # Autocast at level O2 is enabled for every run, including the
            # fp32 (use_pure_fp16=False) ones.
            with paddle.amp.auto_cast(True, level='O2'):
                out = model(img)
                loss = paddle.nn.functional.cross_entropy(
                    input=out, label=label)
            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
            if not accumulate_grad:
                backward_and_step(avg_loss)
        if accumulate_grad:
            # NOTE(review): only the final batch's ``avg_loss`` is
            # backpropagated here; earlier batches of the epoch never call
            # backward(), so no gradient is actually accumulated. Genuine
            # accumulation would call backward() on every batch and step only
            # here. Left as-is so the compared stage-2 / stage-3 runs keep
            # their current meaning - confirm the intended semantics.
            backward_and_step(avg_loss)

    if sharding_stage == 3:
        model.get_all_parameters()
    return model.parameters()


def _assert_params_close(expected_params, actual_params, rtol=1e-6, atol=1e-6):
    """Assert that two trained parameter lists hold matching values.

    Parameters are paired by position, not by name. Paddle generates parameter
    names per layer instance (e.g. ``linear_0.w_0`` vs ``linear_3.w_0``), so
    ``p.name`` never matches across two separately constructed MLPs, and a
    name-based pairing would silently compare nothing. All MLPs here share the
    same architecture, so their ``parameters()`` lists line up index by index.

    Args:
        expected_params (list): parameters returned by one ``train_mlp`` run.
        actual_params (list): parameters returned by another ``train_mlp`` run.
        rtol (float): relative tolerance for ``np.testing.assert_allclose``.
        atol (float): absolute tolerance for ``np.testing.assert_allclose``.

    Raises:
        AssertionError: if the lists differ in length or any pair of
            parameters differs beyond the tolerances.
    """
    assert len(expected_params) == len(actual_params), (
        "parameter count mismatch: {} vs {}".format(
            len(expected_params), len(actual_params)))
    for expected, actual in zip(expected_params, actual_params):
        np.testing.assert_allclose(
            expected.numpy(), actual.numpy(), rtol=rtol, atol=atol)


def test_stage2_stage3():
    """Check that sharding stages 2 and 3 train an MLP to matching parameters.

    Every run starts from the weights of one reference MLP, copied through
    ``set_state_dict``. The comparisons are:
    - stage 2 vs stage 3 in fp32;
    - stage 2 vs stage 3 in fp32 with ``accumulate_grad``;
    - stage 2 vs stage 3 in pure fp16;
    - stage 3 with vs without ``recompute`` (``sync_comm``) in fp16.
    """
    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = [
        MLP() for _ in range(9)
    ]
    state_dict = mlp.state_dict()
    for model in (mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8):
        model.set_state_dict(state_dict)

    # fp32
    stage2_params = train_mlp(
        mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
    stage3_params = train_mlp(
        mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
    _assert_params_close(stage2_params, stage3_params)

    # fp32 accumulate grad
    stage2_params = train_mlp(
        mlp3,
        sharding_stage=2,
        use_pure_fp16=False,
        accumulate_grad=True,
        opt_group=True)
    stage3_params = train_mlp(
        mlp4,
        sharding_stage=3,
        use_pure_fp16=False,
        accumulate_grad=True,
        opt_group=True)
    _assert_params_close(stage2_params, stage3_params)

    # fp16: under AMP O2 the compared parameters are float16, whose machine
    # epsilon (2**-10, about 9.8e-4) is far coarser than rtol=1e-6. A one-ulp
    # difference from a different reduction order between the stages must
    # therefore be tolerated.
    stage2_params = train_mlp(
        mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
    stage3_params = train_mlp(
        mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
    _assert_params_close(stage2_params, stage3_params, rtol=1e-4, atol=1e-3)

    # fp16 recompute: same stage, only ``sync_comm`` differs, so results are
    # expected to agree tightly.
    stage3_params = train_mlp(
        mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
    stage3_params_re = train_mlp(
        mlp8,
        sharding_stage=3,
        use_pure_fp16=True,
        opt_group=False,
        recompute=True)
    _assert_params_close(stage3_params, stage3_params_re)


# Script entry point: run the stage-2 / stage-3 comparison directly.
# NOTE(review): train_mlp builds new_group([0, 1]), so this presumably has to be
# started on two ranks by a distributed launcher - confirm.
if __name__ == '__main__':
    test_stage2_stage3()