# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import tqdm
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist

from .load_checkpoint import load_model_with_checkpoint
import time

from .utils import policy_to_ds_container
import gc


def get_transformer_name(replaced_module):
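    # Best-effort lookup of the dotted prefix of the transformer ModuleList inside
    # ``replaced_module`` (e.g. something like 'transformer.h' for a BLOOM-style model);
    # used further below when saving tp-sharded checkpoints.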
    from .containers import supported_models
    from torch.nn import ModuleList
    transformer_name = ''
    for n, c in replaced_module.named_children():
        if c.__class__ in supported_models:
            transformer_name += n + '.'
            for name, child in c.named_children():
                if child.__class__ is ModuleList:
                    transformer_name += name
                    break
            break
    return transformer_name


class GroupQuantizer:

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8

        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
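        # Symmetric per-group int8 quantization: returns an int8 Parameter whose
        # ``.scale`` attribute stacks the per-group scale of the whole tensor with the
        # scales of its two halves split along ``parallel_dim``; with ``q_int8`` disabled
        # or ``qkv=False`` the input is simply wrapped in a Parameter with a placeholder scale.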
        if not self.q_int8 or not qkv:
            inputs = torch.nn.Parameter(inputs, requires_grad=False)
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
                  for i in range(2)]

        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
                                                                                                   -1).contiguous()
        return out


def _module_match(module):
    for policy in generic_policies:
        policy = policy()
        if policy.match(module):
            return policy
    return None


def generic_injection(module, dtype=None, enable_cuda_graph=True):
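    # Injects DeepSpeed inference kernels into pipelines that the transformer policies do
    # not cover: ``module`` is expected to be a diffusers-style pipeline object whose
    # attention blocks and CLIP text encoder get replaced, not a bare ``nn.Module``.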

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            dtype=dtype,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data = data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            attn_module.attn_qkvw = None
            attn_module.attn_qw.data = transpose(qw.data)
            attn_module.attn_kw.data = transpose(kw.data)
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        pass
    else:
        if dtype not in [torch.float16, torch.half]:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)


container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py
    Returns:
        Updated nn.module with replaced transformer layers
    """
    # defining globals as internally defined functions inherit these everywhere
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # the policy says cuda graph is not supported; raise an error if it was requested
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if config.dtype in [torch.float16, torch.bfloat16, torch.int8]:
            _container.convert_to_required_dtype()

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
        #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

        # 1. Create AutoTP object
        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)

        # 2. Set the tensor parallelism config
        _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Set linear policies
        _autotp.update_linear_policies()

        # 4. Replace modules
        return _autotp._replace_module(module)

    def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
        training = False  # todo: refactor this part to go in the config
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, config.triangular_masking)

        else:
            # copy relevant state from child -> new module
            if config.replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 config.triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)

        return new_module

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
        pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
        for i in range(len(checkpoint)):
            replaced_module = replace_module(model=model,
                                             orig_class=orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=config.injection_policy_tuple,
                                             checkpoint=checkpoint[i])
            pbar.update(1)
            gc.collect()
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
                                         replace_fn=replace_fn,
                                         _replace_policy=config.injection_policy_tuple)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
        assert container_g.ckpt_load_enabled, \
               f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
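            # Illustrative numbers: with world_size=2 ranks consuming a checkpoint written
            # with ckpt_mp_size=4, tp_split_size is 0.5, so rank 0 loads shards 0-1 of each
            # group (sd_offset=0, sd_count=2) and rank 1 loads shards 2-3 (sd_offset=2).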
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if config.save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = 'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)

        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({k: v
                             for k, v in dict(replaced_module.state_dict()).items()
                             if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            dtype_reprs = {
                torch.float32: 'float32',
                torch.float16: 'float16',
                torch.int8: 'int8',
                torch.bfloat16: 'bfloat16'
            }

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': dtype_reprs[config.dtype]
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)
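            # The written ds_inference_config.json then looks roughly like (illustrative values):
            #   {"type": "bloom", "base_dir": "<save_mp_checkpoint_path>",
            #    "checkpoints": {"non_tp": ["non-tp.pt"], "tp": ["tp_00_00.pt", ...]},
            #    "version": 1.0, "parallelization": "tp", "tp_size": 2, "dtype": "float16"}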

        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module


def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, axis=0)
        qb, kb, vb = torch.chunk(qkvb, 3, axis=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)


def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
    """ Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    sd = None
    if checkpoint is not None:
        sd = torch.load(checkpoint, map_location='cpu')
    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\
        "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy, state_dict=sd)
    if checkpoint is not None:
        embedding_weight = None
        for n, p in replaced_module.named_parameters():
            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
                embedding_weight = p
        if embedding_weight is not None and hasattr(replaced_module, "lm_head") and hasattr(
                replaced_module.lm_head, "weight") and replaced_module.lm_head.weight.is_meta:
            replaced_module.lm_head.weight = embedding_weight
    return replaced_module


from ..pipe import PipelineModule

import re


def skip_level_0_prefix(model, state_dict):
    model = str(model)
    key = re.search(r": (.*?)Model", model)
    if key is None:
        key = re.search(r": (.*?)Stack", model)
    if key is None:
        key = re.match(r"(.*?)Model", model)
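    # ``str(model)`` renders the module tree, e.g. "BloomForCausalLM(\n  (transformer): BloomModel(...",
    # so the regexes above recover the family name ("Bloom", "OPT", ...) that is checked below.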
    # if keys start with 'model.', don't skip level 0 prefix
    if state_dict is not None:
        for item in state_dict.keys():
            if re.match("^model[.]", item):
                return False
    if key is not None and key.group(1).lower() in ["bloom", "opt"]:
        return True
    return False


def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id,
                                                           prefix=prefix + name,
                                                           state_dict=state_dict)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            checking_key = prefix + name + '.'
            if Loading.is_load_module(child) and state_dict is not None:
                if any(checking_key in item for item in state_dict):
                    Loading.load(
                        child,
                        state_dict,
                        checking_key,
                    )
                else:
                    continue
            if len(child._buffers) != 0 and state_dict is not None:
                Loading.load_buffer(child, state_dict, checking_key)
            _, layer_id = _replace_module(child,
                                          policies,
                                          prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
                                          prefix + name + '.',
                                          layer_id=layer_id,
                                          level_id=level_id + 1,
                                          state_dict=state_dict)

    # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id