未验证 提交 2ede0d94 编写于 作者: M Molly Smith 提交者: GitHub

AutoTP Assert Kernel Injection Support (#2939)

* check kernel injection supported models

* Clarify why user should use kernel injection
上级 4ae3a3da
......@@ -4,6 +4,7 @@
import re
from torch import nn
from .replace_policy import replace_policies
class AutoTP():
......@@ -73,15 +74,30 @@ class AutoTP():
policy_list.append(tuple([type(new_module), new_gems]))
return policy_list
def kernel_supported(module_list):
policy = []
for plcy in replace_policies:
# instantiate a throw-away policy in order to populate the _orig_layer_class
_ = plcy(None)
if isinstance(plcy._orig_layer_class, list):
for orig_layer_class in plcy._orig_layer_class:
policy.append(orig_layer_class)
elif plcy._orig_layer_class is not None:
policy.append(plcy._orig_layer_class)
for child in module_list:
if child.__class__ in policy:
return True
return False
def tp_parser(model):
policy_list = []
module_list = []
layer_list = []
gem_list = []
assert AutoTP.supported(model), "Automatic policy not supported for model. Please provide policy."
module_list = AutoTP.get_module_list(model)
assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
for module in module_list:
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
......@@ -103,5 +119,6 @@ class AutoTP():
gem_list = list(set(gem_list))
policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
gem_list = []
assert len(policy_list), "Not able to determine model policy automatically. Please provide policy."
assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
return policy_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册