未验证 提交 46784cb5 编写于 作者: M Molly Smith 提交者: GitHub

Fix auto TP for duplicate modules with different gems (#2784)

* Fix auto TP for duplicate modules with different gems

* precommit and comments

* Comment

* Combine gem list of same named modules

* remove duplicates from gem_list before updating policy

* Add module attribute with name variation for ProphetNet

---------
Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 cc1054d9
......@@ -41,12 +41,24 @@ class AutoTP():
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
return layer_list
def update_policy_list(policy_list, new_module, new_gems):
if len(policy_list):
for i, policy in enumerate(policy_list):
# if module already exists in policy, combine gems and remove duplicates
if policy[0] == type(new_module):
new_gems = set(new_gems + policy[1])
policy_list[i] = tuple([type(new_module), new_gems])
return policy_list
policy_list.append(tuple([type(new_module), new_gems]))
return policy_list
def tp_parser(model):
policy_list = []
module_list = []
......@@ -60,7 +72,9 @@ class AutoTP():
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
......@@ -70,7 +84,9 @@ class AutoTP():
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
layer_list = []
if gem_list != []:
policy_list.append(tuple([type(module), gem_list]))
gem_list = list(set(gem_list))
policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
gem_list = []
return policy_list
......@@ -463,6 +463,8 @@ def replace_transformer_layer(orig_layer_impl,
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'num_attn_heads'):
child.num_attn_heads = child.num_attn_heads // mp_size
if hasattr(child, 'all_head_size'):
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册