diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 08cf1fa9bacf068b3f08765c83eac18cdc17c513..cff69b722f3dfa5c2fe0cb636f2aa387898c4886 100644
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -44,15 +44,6 @@ class InferenceEngine(Module):
         self.quantize_merge_count = 1
         self.quantization_scales = None
 
-        if self.mpu:
-            self.mp_world_size = dist.get_world_size(
-                group=self.mpu.get_model_parallel_group())
-            self.mp_group = self.mpu.get_model_parallel_group()
-        elif self.mp_world_size > 1 and not dist.is_initialized():
-            self._create_model_parallel_group()
-        else:
-            self.module.to(torch.cuda.current_device())
-
         self._check_quantize_setting(quantization_setting)
 
         if self.checkpoint:
@@ -62,6 +53,13 @@ class InferenceEngine(Module):
         if self.dtype:
             self._convert_to_dtype()
 
+        if self.mpu:
+            self.mp_world_size = dist.get_world_size(
+                group=self.mpu.get_model_parallel_group())
+            self.mp_group = self.mpu.get_model_parallel_group()
+        elif self.mp_world_size > 1 and not dist.is_initialized():
+            self._create_model_parallel_group()
+
         # apply injection policy
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
@@ -69,6 +67,8 @@ class InferenceEngine(Module):
             elif replace_method == 'auto':
                 self._apply_injection_policy()
 
+        self.module.to(torch.cuda.current_device())
+
         if self.mp_world_size > 1:
             self.model_orig_fwd = self.module.forward
             self.module.forward = self.forward
@@ -96,11 +96,6 @@ class InferenceEngine(Module):
         ranks = [i for i in range(self.mp_world_size)]
         self.mp_group = dist.new_group(ranks)
 
-        self.module.to(torch.cuda.current_device())
-        for p in self.module.parameters():
-            if torch.is_tensor(p):
-                dist.broadcast(p, 0)
-
     def _check_quantize_setting(self, quantization_setting):
         self.quatize_bits = 8
         self.mlp_extra_grouping = False
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 362be101104dcffdac787cb3c9bb705cffd93b0b..09eb90a584f3a51391bceb68e9ab7cb2315072ca 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -137,6 +137,9 @@ def replace_transformer_layer(orig_layer_impl,
         if inference:
             hidden_size, num_attention_heads = policy.get_hidden_heads()
+            assert num_attention_heads % mp_size == 0,\
+                "To run the model in parallel across GPUs, the number of attention heads must be divisible by the model-parallel world size. " +\
+                "This is because the attention computation is partitioned evenly among the parallel GPUs."
         attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
         mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
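For context on the new assertion in replace_module.py: model parallelism splits the attention computation head-by-head across GPUs, so every rank must own an equal share of heads. The sketch below is illustrative only and is not part of the patch or of DeepSpeed's API; the partition_heads helper and its names are hypothetical, but it demonstrates the same divisibility requirement the added assert enforces.

# Illustrative sketch (not DeepSpeed code): why num_attention_heads must be
# divisible by mp_size before the attention weights are partitioned.

def partition_heads(num_attention_heads: int, mp_size: int) -> list:
    """Give every model-parallel rank an equal, contiguous slice of heads."""
    # Mirrors the assertion added in replace_transformer_layer: an uneven
    # split would leave some heads unassigned or duplicated across ranks.
    assert num_attention_heads % mp_size == 0, \
        "attention heads must be divisible by the model-parallel world size"
    heads_per_rank = num_attention_heads // mp_size
    return [
        list(range(rank * heads_per_rank, (rank + 1) * heads_per_rank))
        for rank in range(mp_size)
    ]


if __name__ == "__main__":
    # 16 heads over 4 GPUs -> 4 heads per rank: [[0..3], [4..7], [8..11], [12..15]]
    print(partition_heads(16, 4))
    # 16 heads over 3 GPUs would trip the same check the patch adds.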