# dygraph_recompute_hybrid.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
from paddle.incubate.distributed.fleet import recompute_hybrid
import random
from paddle.distributed import fleet

import paddle.fluid.layers as layers


def get_fc_block(block_idx, input_size, is_last=False):
    """Assemble one fully connected block as a ``paddle.nn.Sequential``.

    The block is Linear -> Dropout(p=0.5) -> ReLU -> Linear -> ReLU -> Linear.
    Every Linear is created with ``bias_attr=False`` and every sublayer is
    named ``block_<block_idx>_<suffix>``.

    Args:
        block_idx: Index used to build the sublayer name prefix.
        input_size: Input width of the block and width of the inner Linears.
        is_last: When true the trailing Linear has a single output feature;
            otherwise it keeps ``input_size`` output features.

    Returns:
        The assembled ``paddle.nn.Sequential``.
    """
    prefix = "block_" + str(block_idx)

    def linear(out_features):
        # All Linear layers of the block share input width and have no bias.
        return paddle.nn.Linear(input_size, out_features, bias_attr=False)

    # Layers are created in the same order as they are applied.
    block = paddle.nn.Sequential(
        (prefix + "_fc_0", linear(input_size)),
        (prefix + "_dropout", paddle.nn.Dropout(p=0.5)),
        (prefix + "_relu_1", paddle.nn.ReLU()),
        (prefix + "_fc_1", linear(input_size)),
        (prefix + "_relu_2", paddle.nn.ReLU()),
    )

    # Only the width of the trailing projection depends on ``is_last``.
    tail_width = 1 if is_last else input_size
    block.add_sublayer(prefix + "_fc_2", linear(tail_width))
    return block


class Naive_fc_net(paddle.nn.Layer):
    """Five stacked FC blocks, a chosen subset of which run under recompute.

    Blocks 0-3 are built with ``is_last=False`` and block 4 with
    ``is_last=True`` (see ``get_fc_block``). During ``forward`` every block
    whose index is in ``recompute_blocks`` is invoked through
    ``recompute_hybrid``; the remaining blocks are called directly.

    Args:
        input_size: Feature width passed to every ``get_fc_block`` call.
        recompute_blocks: Indices of the blocks to run through
            ``recompute_hybrid``. ``None`` (the default) means ``[1, 3]``.
        offload: Forwarded as the ``"offload"`` entry of the
            ``recompute_hybrid`` context dict.
        partition: Forwarded as the ``"partition"`` entry of the
            ``recompute_hybrid`` context dict.
        recompute_kwargs: Extra keyword arguments forwarded to
            ``recompute_hybrid``. ``None`` (the default) means no extras.
    """

    def __init__(self,
                 input_size=10,
                 recompute_blocks=None,
                 offload=False,
                 partition=False,
                 recompute_kwargs=None):
        super(Naive_fc_net, self).__init__()
        # ``None`` sentinels instead of ``[1, 3]`` / ``{}`` defaults: the
        # values are stored on the instance, so a literal default would be a
        # single object shared by every instance built without the argument.
        if recompute_blocks is None:
            recompute_blocks = [1, 3]
        if recompute_kwargs is None:
            recompute_kwargs = {}
        self.recompute_blocks = recompute_blocks
        self.recompute_kwargs = recompute_kwargs
        self.offload = offload
        self.partition = partition

        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
        self.runfunc4 = get_fc_block(4, input_size, is_last=True)

        # Ordered view of the blocks used by ``forward``.
        self.layers = [
            self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3,
            self.runfunc4
        ]

    def forward(self, inputs):
        """Apply the five blocks in order and return the final output."""
        for i, layer in enumerate(self.layers):
            if i in self.recompute_blocks:
                # The model-parallel group is looked up only when a block is
                # actually recomputed, so a model with no recompute blocks
                # never touches the hybrid communicate group.
                inputs = recompute_hybrid(
                    {
                        "mp_group": fleet.fleet._hcg.get_model_parallel_group(),
                        "offload": self.offload,
                        "partition": self.partition
                    }, layer, inputs, **self.recompute_kwargs)
            else:
                inputs = layer(inputs)

        return inputs


def run_model(recompute_block=None,
              recompute_kwargs=None,
              offload=False,
              partition=False,
              enable_autocast=False,
              pure_fp16=False):
    """Train a freshly seeded ``Naive_fc_net`` for 10 steps and record results.

    All random generators used here (paddle, numpy, ``random``) are seeded
    with 10 before the model is built, so separate calls can be compared.

    Args:
        recompute_block: Block indices passed to ``Naive_fc_net`` as
            ``recompute_blocks``. ``None`` (the default) means no block.
        recompute_kwargs: Extra keyword arguments passed through to
            ``Naive_fc_net``. ``None`` (the default) means none.
        offload: Passed through to ``Naive_fc_net``.
        partition: Passed through to ``Naive_fc_net``.
        enable_autocast: When true, backward/update go through a distributed
            ``GradScaler``; otherwise plain ``backward()`` and
            ``optimizer.step()`` are used.
        pure_fp16: Selects AMP level ``'O2'`` instead of ``'O1'``.

    Returns:
        A ``(loss_, param_, grad_)`` tuple of lists. ``loss_`` holds the loss
        of each step but is only filled when ``enable_autocast`` is false.
        ``param_`` holds ``model.parameters()[9]`` and ``grad_`` the gradient
        of ``model.parameters()[3]``, both once per step.
    """
    # ``None`` sentinels replace the former mutable ``[]`` / ``{}`` defaults;
    # they are normalised here so the model always receives real containers.
    if recompute_block is None:
        recompute_block = []
    if recompute_kwargs is None:
        recompute_kwargs = {}

    gen = paddle.seed(10)
    gen.manual_seed(10)
    np.random.seed(10)
    random.seed(10)

    batch_size, input_size = 1, 10
    model = Naive_fc_net(input_size,
                         recompute_blocks=recompute_block,
                         offload=offload,
                         partition=partition,
                         recompute_kwargs=recompute_kwargs)
    loss_fn = paddle.nn.MSELoss(reduction='mean')
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())

    model = fleet.distributed_model(model)
    optimizer = fleet.distributed_optimizer(optimizer)

    if enable_autocast:
        scaler = paddle.amp.GradScaler()
        scaler = fleet.distributed_scaler(scaler)

    loss_ = []
    param_ = []
    grad_ = []
    for _ in range(10):

        x_data = np.random.randn(batch_size, input_size).astype(np.float32)
        x = paddle.to_tensor(x_data)
        # x.stop_gradient is deliberately left at its default here.
        level = 'O2' if pure_fp16 else 'O1'
        # NOTE(review): auto_cast is entered with enable=True on every call,
        # regardless of ``enable_autocast``; that flag only selects the
        # GradScaler path below. Confirm this is intended.
        with paddle.amp.auto_cast(True, level=level):
            y_pred = model(x)
            loss = y_pred.mean()
        if enable_autocast:
            scaler.scale(loss).backward()
            scaler.minimize(optimizer, loss)
        else:
            # Losses are recorded on this path only.
            loss_.append(np.asarray(loss).tolist())
            loss.backward()
            optimizer.step()

        # Snapshot one parameter and one gradient before grads are cleared.
        param_.append(np.asarray(model.parameters()[9]).tolist())
        grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())

        optimizer.clear_grad()
    return loss_, param_, grad_


class TestPyLayer(unittest.TestCase):
    """Compare ``run_model`` with and without hybrid recompute.

    Each test initialises fleet with mp_degree=2, dp_degree=1, pp_degree=1
    and checks that several recompute configurations reproduce the records
    of a run with no recomputed block.
    """

    def setUp(self):
        self.model_parallel_size = 2
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 1

        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_base_case(self, enable_autocast=False, pure_fp16=False):
        """Check every recompute variant against the no-recompute run."""
        amp_options = {
            "enable_autocast": enable_autocast,
            "pure_fp16": pure_fp16,
        }

        # Reference: no block is recomputed.
        loss_ref, param_ref, grad_ref = run_model(recompute_block=[],
                                                  **amp_options)
        reference = (loss_ref, param_ref, grad_ref)

        # Options left out of a variant keep run_model's defaults (False).
        variants = (
            # offload=False, partition=False
            {"recompute_block": [1, 3]},
            # offload=True, partition=False
            {"recompute_block": [1, 2, 3], "offload": True},
            # offload=False, partition=True
            {"recompute_block": [1], "partition": True},
            # offload=True, partition=True
            {"recompute_block": [1, 3, 4], "offload": True, "partition": True},
        )

        for variant in variants:
            options = dict(variant)
            options.update(amp_options)
            loss, param, grad = run_model(**options)
            # Compared in order: losses, then parameters, then gradients.
            for expected, actual in zip(reference, (loss, param, grad)):
                self.assertEqual(expected, actual)

    def test_fc_net_with_dropout(self):
        self.test_base_case()

    def test_fc_net_with_amp(self):
        self.test_base_case(enable_autocast=True)

    def test_fc_net_with_fp16(self):
        self.test_base_case(enable_autocast=True, pure_fp16=True)

    def test_recompute_kwargs(self):
        """An unexpected recompute keyword argument must raise TypeError."""
        paddle.set_device("gpu")
        extra_kwargs = {"is_test": False}
        with self.assertRaises(TypeError):
            run_model(recompute_block=[2], recompute_kwargs=extra_kwargs)


# Run the test cases defined above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()