Unverified commit 49b6a632, authored by Reza Yazdani, committed by GitHub

Reducing the memory overhead of creating the model for multi-GPU runs (#1244)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 274c375c
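Context for the diff below: the commit defers moving the module onto the GPU (and setting up the model-parallel group) until after dtype conversion and kernel injection have run, so layers are replaced before the full model is copied to the device. The following is a minimal, hedged usage sketch of the code path this touches; it assumes the `deepspeed.init_inference` keyword names of this release and a Hugging Face GPT-2 model as an example workload, neither of which is part of the commit itself.

```python
# Minimal sketch (not part of the commit): building an InferenceEngine with
# model parallelism, the path reorganized by this change.
# Assumes deepspeed.init_inference accepts these keywords in this release;
# launch with e.g. `deepspeed --num_gpus 2 script.py`.
import torch
import deepspeed
from transformers import GPT2LMHeadModel  # example model, an assumption

model = GPT2LMHeadModel.from_pretrained("gpt2")

engine = deepspeed.init_inference(
    model,
    mp_size=2,              # > 1 exercises _create_model_parallel_group()
    dtype=torch.half,       # handled by _convert_to_dtype() in the engine
    replace_method='auto',  # follows the auto-injection branch shown in the diff
)
```

With the new ordering, `self.module.to(torch.cuda.current_device())` runs once, after the injection policy has been applied, instead of at construction time and again inside `_create_model_parallel_group`.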
@@ -44,15 +44,6 @@ class InferenceEngine(Module):
self.quantize_merge_count = 1
self.quantization_scales = None
if self.mpu:
self.mp_world_size = dist.get_world_size(
group=self.mpu.get_model_parallel_group())
self.mp_group = self.mpu.get_model_parallel_group()
elif self.mp_world_size > 1 and not dist.is_initialized():
self._create_model_parallel_group()
else:
self.module.to(torch.cuda.current_device())
self._check_quantize_setting(quantization_setting)
if self.checkpoint:
@@ -62,6 +53,13 @@ class InferenceEngine(Module):
if self.dtype:
self._convert_to_dtype()
if self.mpu:
self.mp_world_size = dist.get_world_size(
group=self.mpu.get_model_parallel_group())
self.mp_group = self.mpu.get_model_parallel_group()
elif self.mp_world_size > 1 and not dist.is_initialized():
self._create_model_parallel_group()
# apply injection policy
if self.injection_dict:
for client_module, injection_policy in self.injection_dict.items():
@@ -69,6 +67,8 @@ class InferenceEngine(Module):
elif replace_method == 'auto':
self._apply_injection_policy()
self.module.to(torch.cuda.current_device())
if self.mp_world_size > 1:
self.model_orig_fwd = self.module.forward
self.module.forward = self.forward
@@ -96,11 +96,6 @@ class InferenceEngine(Module):
ranks = [i for i in range(self.mp_world_size)]
self.mp_group = dist.new_group(ranks)
self.module.to(torch.cuda.current_device())
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, 0)
def _check_quantize_setting(self, quantization_setting):
self.quatize_bits = 8
self.mlp_extra_grouping = False
@@ -137,6 +137,9 @@ def replace_transformer_layer(orig_layer_impl,
if inference:
hidden_size, num_attention_heads = policy.get_hidden_heads()
        assert num_attention_heads % mp_size == 0,\
            "To run the model in parallel across the GPUs, attention_heads must be divisible by the world_size! " +\
            "This is because the attention computation is partitioned evenly among the parallel GPUs."
attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
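The assertion added in the last hunk guards the even split of attention heads across the model-parallel ranks. A small illustrative sketch of the constraint follows; the helper below is hypothetical and not part of DeepSpeed.

```python
# Hypothetical helper, for illustration only: each model-parallel rank owns
# num_attention_heads // mp_size heads, so the division must be exact.
def heads_per_rank(num_attention_heads: int, mp_size: int) -> int:
    assert num_attention_heads % mp_size == 0, (
        "attention heads must be divisible by the model-parallel world size")
    return num_attention_heads // mp_size

print(heads_per_rank(16, 4))  # 4 heads per GPU
# heads_per_rank(16, 3) would fail, just as the new assertion in the diff does
```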