'''Copyright The Microsoft DeepSpeed Team'''

import numpy as np
import torch
import pytest
import random
import copy
from torch import nn
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from deepspeed.accelerator import get_accelerator
from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln
from unit.modelingpreln import BertEncoder as BertEncoderPreln
from unit.common import DistributedTest

#if not deepspeed.ops.__installed_ops__['transformer']:
#pytest.skip(
#    "transformer kernels are temporarily disabled because of unexplained failures",
#    allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
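    """Compare two lists of [gradient, tag] pairs produced by the two encoders.

    Gradients are bucketed by tag (a counter disambiguates repeated tags),
    converted to numpy, and checked with np.testing.assert_allclose using a
    tolerance scaled by the mean absolute value of the reference gradient.
    """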
    diction_x = {}
    diction_y = {}

    if verbose:
        for i, (x, y) in enumerate(zip(first, second)):
            print(x[1], y[1])

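    # Bucket gradients by tag; the counter k keeps repeated tags from
    # overwriting each other.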
    for i, (x, y) in enumerate(zip(first, second)):
        k = 0
        while (diction_x.get((k, x[1])) is not None):
            k = k + 1
        diction_x[k, x[1]] = x[0]
        k = 0
        while (diction_y.get((k, y[1])) is not None):
            k = k + 1
        diction_y[k, y[1]] = y[0]
    if verbose:
        print()
        for i, (x, y) in enumerate(zip(diction_x, diction_y)):
            print(x, y)

    for i, (x, y) in enumerate(zip(diction_x, diction_y)):
        if (x[0] == 1): continue
        if verbose:
            print("checking ", x[1], ":")
        y = diction_y[x[0], x[1]]
        x = diction_x[x[0], x[1]]

        if verbose:
            print(((x == float('inf')).nonzero(as_tuple=True)[0]))
            print(((y == float('inf')).nonzero(as_tuple=True)[0]))
        x = x.cpu().detach().numpy()
        y = y.cpu().detach().numpy()

        # Scale the tolerance by the mean absolute value of the reference tensor.
        # Use a separate loop variable so the outer `i` used in the assert
        # message below is not clobbered.
        avgx = np.sum(abs(x), dtype=float)
        countx = x.shape[0]
        for dim in range(1, len(x.shape)):
            countx *= x.shape[dim]
        tolerance = 1
        if avgx != float('inf') and avgx != -float('inf'):
            avgx = avgx / countx
            tolerance = avgx * atol
        if verbose:
            print("tolerance is ", tolerance)
            x = x.flatten()
            y = y.flatten()
            print("x = {}".format(x))
            print("y = {}".format(y))
            if any(x == float('inf')) or any(x == -float('inf')):
                print("found infinity in x")
            if any(y == float('inf')) or any(y == -float('inf')):
                print("found infinity in y")
            print(np.linalg.norm(x.astype('float64')))
            print(np.linalg.norm(y.astype('float64')))
            print('-' * 80)
        #toler = np.linalg.norm(x.astype('float64')) * 0.0005
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=tolerance)


def zero_grad(variables):
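    """Zero the accumulated .grad of every tensor in `variables`."""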
    for variable in variables:
        variable.grad.zero_()


device = torch.device(get_accelerator().device_name())
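# Tensor-construction kwargs for the fp32 and fp16 test paths on the active accelerator.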
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}


class DSEncoder(nn.Module):
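    """Encoder stack built from DeepSpeedTransformerLayer modules.

    Structurally mirrors the reference BertEncoder implementations from
    unit.modeling / unit.modelingpreln so that per-layer gradients from the
    two implementations can be collected and compared.
    """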

    def __init__(self, config, weights, biases):
        super(DSEncoder, self).__init__()
        self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.layer = nn.ModuleList([
            copy.deepcopy(DeepSpeedTransformerLayer(config, weights, biases)) for _ in range(config.num_hidden_layers)
        ])
        self.grads = []
        self.pre_or_post = config.pre_layer_norm

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
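        """Run every transformer layer; backward hooks record output gradients in self.grads."""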
        all_encoder_layers = []

        def custom(start, end):

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            raise NotImplementedError("`checkpoint` is not defined below")
            #l = 0
            #num_layers = len(self.layer)
            #chunk_length = math.ceil(math.sqrt(num_layers))
            #while l < num_layers:
            #    hidden_states = checkpoint.checkpoint(
            #        custom(
            #            l,  # noqa: F821
            #            l + chunk_length),
            #        hidden_states,
            #        attention_mask * 1)
            #    l += chunk_length
            # decoder layers
        else:
            for i, layer_module in enumerate(self.layer):
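                # grads=self.grads lets the DeepSpeed layer append its own
                # gradients; the hook below additionally records the gradient
                # of the layer output under the "hidden_state" tag.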
                hidden_states = layer_module(hidden_states, attention_mask, grads=self.grads)
                hidden_states.register_hook(lambda x, self=self: self.grads.append([x, "hidden_state"]))

                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            if (self.pre_or_post):
                hidden_states = self.FinalLayerNorm(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

    def get_grads(self):
        return self.grads


def create_models(ds_config):
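    """Build a reference BertEncoder (pre- or post-LayerNorm) and a DSEncoder
    sharing the same randomly initialized weights and biases, cast to fp16 if
    requested and moved to the active accelerator."""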
    bert_config = BertConfig(vocab_size_or_config_json_file=119547,
                             hidden_size=ds_config.hidden_size,
                             num_hidden_layers=ds_config.num_hidden_layers,
                             num_attention_heads=ds_config.heads,
                             intermediate_size=ds_config.intermediate_size,
                             hidden_act="gelu",
                             hidden_dropout_prob=ds_config.hidden_dropout_ratio,
                             attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
                             max_position_embeddings=512,
                             type_vocab_size=2,
                             initializer_range=ds_config.initializer_range)

    weights = []
    biases = []

    for i in range(4):
        weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.hidden_size)))
        weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)

    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[4].data.fill_(1.0)
    weights.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size, ds_config.hidden_size)))
    weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.intermediate_size)))
    weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[7].data.fill_(1.0)

    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[0].data.zero_()
    for i in range(4):
        biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
        biases[i + 1].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
    biases[5].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[6].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[7].data.zero_()

    if (ds_config.pre_layer_norm):
        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
    else:
        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
    ds_encoder = DSEncoder(ds_config, weights, biases)

    if ds_config.fp16:
        bert_encoder.half()
        ds_encoder.half()

    bert_encoder.to(get_accelerator().device_name())
    ds_encoder.to(get_accelerator().device_name())

    return bert_encoder, ds_encoder


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
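    """Run a forward/backward pass through both encoders on identical random
    inputs and assert that the collected gradients agree within `atol`."""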
    set_seed(123)
    bert_encoder, ds_encoder = create_models(ds_config)

    # prepare test data
    kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
    hidden_states = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs)
    input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs)
    Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs)

    # run baseline
    base_results = bert_encoder(hidden_states,
                                input_mask,
                                output_all_encoded_layers=False,
                                checkpoint_activations=False)

    loss = (Y - base_results[0]).pow(2).sum() / 64
    loss.backward()
    base_grads = bert_encoder.get_grads()

    # run ds
    ds_results = ds_encoder(hidden_states, input_mask, output_all_encoded_layers=False, checkpoint_activations=False)

    loss = (Y - ds_results[0]).pow(2).sum() / 64
    loss.backward()
    ds_grads = ds_encoder.get_grads()

    # check grads
    check_equal(base_grads, ds_grads, atol=atol, verbose=verbose)


#test_backward[3-1024-120-16-24-True-True-0.05]
#test_backward[3-1024-52-16-24-False-True-0.2]
# 3-128-54-2-24-False-True-0.2
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
                         [
                             (64,160,128,2,24,False,True, 0.2),
                             (64,1600,128,2,4,False,True, 0.2),
                             (8,1600,128,25,3,True,True, 0.05),
                             (8,160,128,2,3,True,True, 0.1),
                             (8,1600,128,2,3,True,True, 0.05),
                             #(3,1024,119,16,24,True,False, 0.05),
                             #(3,1024,115,16,24,True,True, 0.05),
                             #(1024,128,10,2,2,False,False, 0.1),
                             #(3,1024,52,16,24,False,True, 0.2),
                             #(3,128,51,2,24,False,False, 0.1),
                             #(3,128,54,2,24,False,True, 0.2),
                         ]) # yapf: disable
class TestCUDABackward(DistributedTest):
    world_size = 1

    def test_backward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol):
        # Only run fp16 test cases on devices with FP16 capability.
        if not get_accelerator().is_fp16_supported() and (use_fp16 is True or is_preln is False):
            return

        ds_config = DeepSpeedTransformerConfig()
        ds_config.layer_id = None
        ds_config.batch_size = batch_size
        ds_config.hidden_size = hidden_size
        ds_config.intermediate_size = hidden_size
        ds_config.heads = heads
        ds_config.attn_dropout_ratio = 0.0
        ds_config.hidden_dropout_ratio = 0.0
        ds_config.num_hidden_layers = num_layers
        ds_config.pre_layer_norm = is_preln
        ds_config.initializer_range = 0.02
        ds_config.fp16 = use_fp16

        run_backward(ds_config, seq_len, atol=atol, verbose=True)

    #                         [
    #                             (3,1024,128,16,24,True,False, 0.07),
    #                             (3,1024,128,16,24,True,True, 0.05),
    #                             (3,1024,128,16,24,False,False, 0.1),
    #                             (3,1024,128,16,24,False,True, 0.2),
    #                         ]) # yapf: disable
    #def test_backward_stochastic(batch_size,
    #                             hidden_size,
    #                             seq_len,
    #                             heads,
    #                             num_layers,
    #                             is_preln,
    #                             use_fp16,
    #                             atol):
    #    # Only run fp16 test cases on devices with FP16 capability.
    #    if not get_accelerator().is_fp16_supported() and use_fp16 is True:
    #        return
    #
    #    ds_config = DeepSpeedTransformerConfig()
    #    ds_config.layer_id = None
    #    ds_config.batch_size = batch_size
    #    ds_config.hidden_size = hidden_size
    #    ds_config.intermediate_size = 4 * hidden_size
    #    ds_config.max_seq_length = seq_len
    #    ds_config.heads = heads
    #    ds_config.attn_dropout_ratio = 0.0
    #    ds_config.hidden_dropout_ratio = 0.0
    #    ds_config.num_hidden_layers = num_layers
    #    ds_config.pre_layer_norm = is_preln
    #    ds_config.initializer_range = 0.02
    #    ds_config.fp16 = use_fp16
    #    ds_config.stochastic_mode = True
    #
    #    run_backward(ds_config, atol=atol)