# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
from legacy_test.hybrid_parallel_pp_layer import AlexNet, AlexNetPipeDesc

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.utils.nccl_utils import check_nccl_version_for_bf16


def set_random_seed(seed, dp_id, rank_id):
    """Seed the Python, NumPy and Paddle RNGs for reproducibility.

    Args:
        seed: Base seed value. Python's ``random`` is seeded with it as-is.
        dp_id: Data-parallel rank. It is added to ``seed`` for the NumPy and
            Paddle generators, so each data-parallel rank seeds those two
            differently.
        rank_id: Global rank. Accepted for call compatibility; not used.
    """
    offset_seed = seed + dp_id
    random.seed(seed)
    np.random.seed(offset_seed)
    paddle.seed(offset_seed)


# Global batch fed to both models per step, and the micro-batch size used by
# the pipeline model (accumulate_steps = batch_size // micro_batch_size).
batch_size, micro_batch_size = 4, 2


class TestDistPPTraining(unittest.TestCase):
    """Compare bf16 pipeline-parallel training against a plain-model baseline.

    ``model_a`` is a plain ``AlexNet``. ``model_b`` is the pipeline-parallel
    ``AlexNetPipeDesc`` initialised from the same weights. Both are decorated
    for bf16 AMP at level O2 and trained on identical MNIST batches. The
    per-step losses must agree to a relative tolerance of 5e-3.
    """

    def setUp(self):
        """Initialise fleet for a 2-stage pipeline with no DP/MP split."""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            # Each global batch is processed as this many micro-batches.
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_pp_model(self):
        """Train both models for ``num_steps`` batches and compare losses."""
        num_steps = 5  # number of batches to train on and compare

        hcg = fleet.get_hybrid_communicate_group()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        # One clip object shared by both optimizers so they clip identically.
        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)

        # construct model a: the reference model, trained without pipelining
        model_a = AlexNet(10)
        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer_a = paddle.optimizer.SGD(
            learning_rate=scheduler_a,
            grad_clip=grad_clip,
            parameters=model_a.parameters(),
        )

        # Static loss scale of 1, i.e. the loss is not rescaled.
        scaler_a = paddle.amp.GradScaler(
            init_loss_scaling=1, use_dynamic_loss_scaling=False
        )

        # construct model b: the same network described as pipeline stages
        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer_b = paddle.optimizer.SGD(
            learning_rate=scheduler_b,
            grad_clip=grad_clip,
            parameters=model_b.parameters(),
        )

        # Copy model a's weights into this stage's parameters of model b so
        # that both models start from identical weights.
        # NOTE(review): the offset ``pp_id * (param_len // 2)`` assumes each of
        # the two stages owns exactly half of AlexNet's parameters, in order.
        # Revisit if pipeline_parallel_size or the layer split changes.
        parameters = [param.numpy() for param in model_a.parameters()]
        param_len = len(parameters)
        for idx, param in enumerate(model_b.parameters()):
            param.set_value(parameters[idx + pp_id * (param_len // 2)])

        # Decorate both models for bf16 O2 training (save_dtype='float32').
        model_a, optimizer_a = paddle.amp.decorate(
            models=model_a,
            optimizers=optimizer_a,
            level='O2',
            dtype='bfloat16',
            save_dtype='float32',
        )
        model_b, optimizer_b = paddle.amp.decorate(
            models=model_b,
            optimizers=optimizer_b,
            level='O2',
            dtype='bfloat16',
            save_dtype='float32',
        )

        # Wrap model b, its optimizer and its scaler for fleet execution.
        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)
        scaler_b = paddle.amp.GradScaler(
            init_loss_scaling=1, use_dynamic_loss_scaling=False
        )
        scaler_b = fleet.distributed_scaler(scaler_b)

        # construct reader
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True
        )

        for step_id, data in enumerate(train_reader()):
            # Check the step budget before converting the batch, so no work is
            # spent on a batch that is never trained on. Use ``break`` rather
            # than returning a value: unittest deprecates (Python 3.11+) test
            # methods that return anything other than None.
            if step_id >= num_steps:
                break

            x_data = (
                np.array([x[0] for x in data])
                .astype('float32')
                .reshape(batch_size, 1, 28, 28)
            )
            y_data = (
                np.array([x[1] for x in data])
                .astype('int64')
                .reshape(batch_size, 1)
            )
            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)
            img.stop_gradient = True
            label.stop_gradient = True

            # Reference step: forward under bf16 autocast, then backward,
            # update, clear_grad and LR step done by hand.
            with paddle.amp.auto_cast(
                enable=True, dtype='bfloat16', level='O2'
            ):
                loss_a = model_a(img, label)
            scaler_a.scale(loss_a).backward()
            scaler_a.minimize(optimizer_a, loss_a)
            optimizer_a.clear_grad()
            scheduler_a.step()

            # Pipeline step: train_batch receives the optimizer, scheduler
            # and scaler and returns the loss.
            with paddle.amp.auto_cast(
                enable=True, dtype='bfloat16', level='O2'
            ):
                loss_b = model_b.train_batch(
                    [img, label], optimizer_b, scheduler_b, scaler=scaler_b
                )

            print("loss: ", loss_a.numpy(), loss_b.numpy())
            np.testing.assert_allclose(
                loss_a.numpy(), loss_b.numpy(), rtol=5e-3
            )


if __name__ == "__main__":
    # Run only where bf16 collectives are usable. The NCCL build must pass
    # ``check_nccl_version_for_bf16``, and the current CUDA device must report
    # a major version of at least 8 (presumably the compute capability, i.e.
    # Ampere or newer). Otherwise exit without running any test. The device is
    # queried only after the NCCL check passes.
    bf16_supported = check_nccl_version_for_bf16()
    if bf16_supported:
        bf16_supported = paddle.device.cuda.get_device_properties().major >= 8
    if bf16_supported:
        unittest.main()