# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import math
import torch
import torch.nn.functional as F
import pytest
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from deepspeed.ops.op_builder import OpBuilder
from deepspeed.utils import safe_get_full_grad
import numpy.testing as npt
from unit.common import DistributedTest

from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
    pytest.skip("skip inference tests on rocm for now", allow_module_level=True)


def to_device(batch, device):
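    """Move every tensor in `batch` to `device`; non-tensor values are passed through unchanged."""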
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except AttributeError:
            # non-tensor values (e.g. strings) are kept as-is
            output[k] = v
    return output


def convert_linear_layer_to_lora(model, part_module_name, lora_dim=0, lora_scaling=1, lora_dropout=0):
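    """Replace every torch.nn.Linear whose qualified name contains part_module_name with a
    LinearLayer_LoRA wrapper (an empty part_module_name matches all Linear layers)."""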
    from deepspeed.compression.helper import recursive_getattr, recursive_setattr

    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(module.weight, lora_dim, lora_scaling, lora_dropout,
                               module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model


class LinearLayer_LoRA(torch.nn.Module):
    # a simple implementation of LoRA
    # for now only supports Linear layers
    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias

        if lora_dim <= 0:
            raise ValueError("You are trying to use LoRA, whose reduced dim should be larger than 0")

        try:
            # for zero stage 3: a partitioned parameter exposes its full shape via ds_shape
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape
        # stored pre-transposed so the forward pass does not need to transpose again
        self.lora_right_weight = torch.nn.Parameter(torch.zeros(columns, lora_dim))
        self.lora_left_weight = torch.nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_dropout > 0:
            self.lora_dropout = torch.nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = torch.nn.Identity()

        self.reset_parameters()
        # disable the original weight gradient
        self.weight.requires_grad = False
        # fuse LoRA to the original weight
        self.fuse_lora = False

    def eval(self):
        self.lora_dropout.eval()

    def train(self, mode=True):
        self.lora_dropout.train(mode)

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_left_weight)

    def forward(self, input):
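        # When fused, the LoRA update has already been folded into self.weight.
        # Otherwise compute F.linear(x, W, b) plus the low-rank correction
        # scaling * dropout(x) @ A @ B, where A = lora_right_weight (columns, lora_dim)
        # and B = lora_left_weight (lora_dim, rows).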
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        else:
            return F.linear(input, self.weight, self.bias) + (
                self.lora_dropout(input) @ self.lora_right_weight @ self.lora_left_weight) * self.lora_scaling
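
    # The test below fuses/unfuses LoRA through the hybrid engine (model.fuse_lora_weight()).
    # The two methods here are only an illustrative sketch of the same fold/unfold math for
    # this standalone module; they are not called by the test.
    def fuse_lora_weight(self):
        if not self.fuse_lora:
            # W <- W + scaling * (B^T A^T), matching the unfused forward pass above
            self.weight.data += self.lora_scaling * torch.matmul(self.lora_left_weight.t(),
                                                                 self.lora_right_weight.t())
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * torch.matmul(self.lora_left_weight.t(),
                                                                 self.lora_right_weight.t())
        self.fuse_lora = False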


def only_optimize_lora_parameters(model):
    # turn off the gradient of all the parameters except the LoRA parameters
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
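
# Illustrative usage of the two helpers above (a sketch mirroring what test_lora does below):
#   model = convert_linear_layer_to_lora(model, "", lora_dim=8)
#   model = only_optimize_lora_parameters(model)
#   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
#   # only lora_right_weight / lora_left_weight parameters should remain trainable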


@pytest.mark.seq_inference
@pytest.mark.parametrize("batch_size", [1], ids=["bsz=1"])
@pytest.mark.parametrize("zero_stage", [2, 3], ids=["zero_stage=2", "zero_stage=3"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-125m", "facebook/opt-350m", "bigscience/bloom-560m"])
@pytest.mark.parametrize("offload_device", ["none", "cpu"])
class TestHybridEngineLoRA(DistributedTest):
    world_size = 1

    def get_model(self, model_name):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model_config = AutoConfig.from_pretrained(model_name)
        model_config.dropout = 0.0
        model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
        model = model.half()
        model = model.to(f'cuda:{local_rank}')
        return model

    def get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def get_train_sentences(self, batch_size):
        sentences = [
            r"\n\nHuman: I am trying to write a fairy tale. What is the most popular plot?\n\n"
            r"Assistant: The most popular plot might be a princess goes to a faraway land, falls in love",
            r"\n\nHuman: What flowers should I grow to attract bees?\n\nAssistant: The reason you want bees "
            r"in your garden is to attract pollinators and get more fruit or vegetable production."
        ]
        if batch_size <= 2:
            return sentences[:batch_size]
        else:
            raise NotImplementedError(f"batch_size {batch_size} not implemented")

    def test_lora(self, batch_size, model_name, zero_stage, offload_device):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        model = self.get_model(model_name)
        tokenizer = self.get_tokenizer(model_name)
        train_sentences = self.get_train_sentences(batch_size)

        # Inject LoRA
        model = convert_linear_layer_to_lora(model, "", 8)
        model = only_optimize_lora_parameters(model)

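        # DeepSpeed config: ZeRO (stage 2 or 3, optionally offloading optimizer state to CPU)
        # for training, with the hybrid engine enabled for inference-optimized evaluation.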
        ds_config = {
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.0,
                    "betas": [0.9, 0.95]
                }
            },
            "train_batch_size": batch_size,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 12
            },
            "hybrid_engine": {
                "enabled": True,
                "pin_parameters": True
            },
            "zero_optimization": {
                "stage": zero_stage,
                "offload_optimizer": {
                    "device": offload_device
                }
            }
        }

        model, *_ = deepspeed.initialize(model=model, config=ds_config)

        # Snapshot the layer-0 base parameters before training; the LoRA gradient norm below
        # should be larger than 0 while these base parameters stay unchanged
        before_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]

        model.train()
        batch = tokenizer(train_sentences, max_length=16, padding="max_length", truncation=True, return_tensors="pt")
        batch = to_device(batch, f'cuda:{local_rank}')
        batch["labels"] = batch["input_ids"]
        outputs = model(**batch, use_cache=False)
        loss = outputs.loss
        model.backward(loss)

        grad_norm_dict = dict()
        for name, param in model.named_parameters():
            if param.requires_grad:
                grad_norm_dict[name] = torch.norm(safe_get_full_grad(param))

        model.step()
        grad_norm = sum([ele.detach().cpu().numpy() for ele in grad_norm_dict.values()])
        assert grad_norm > 1E-5

        # Verify the frozen base parameters remain unchanged after the LoRA-only optimizer step
        after_grad_update_layer0_params = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]
        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params):
            npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)

        # Verify that fusing the LoRA weights mutates layer_params
        model.eval()
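        # Under ZeRO stage 3 the parameters are partitioned across ranks; gather them first,
        # then ask the hybrid engine to fuse the LoRA weights into the base weights.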
        with GatheredParameters(model.parameters()):
            model.fuse_lora_weight()

        after_grad_update_layer0_params_lora_fused = [
            ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
            if ele is not None and len(ele.shape) > 1
        ]

        for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params_lora_fused):
            with pytest.raises(AssertionError):
                npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)

        with GatheredParameters(model.parameters()):
            model.unfuse_lora_weight()