From 2ede0d942ad0319f8bb86f17f01b1474784aa4a0 Mon Sep 17 00:00:00 2001 From: Molly Smith <112220543+molly-smith@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:23:55 -0800 Subject: [PATCH] AutoTP Assert Kernel Injection Support (#2939) * check kernel injection supported models * Clarify why user should use kernel injection --- deepspeed/module_inject/auto_tp.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index d55e204a..a2c570f5 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -4,6 +4,7 @@ import re from torch import nn +from .replace_policy import replace_policies class AutoTP(): @@ -73,15 +74,30 @@ class AutoTP(): policy_list.append(tuple([type(new_module), new_gems])) return policy_list + def kernel_supported(module_list): + policy = [] + for plcy in replace_policies: + # instantiate a throw-away policy in order to populate the _orig_layer_class + _ = plcy(None) + if isinstance(plcy._orig_layer_class, list): + for orig_layer_class in plcy._orig_layer_class: + policy.append(orig_layer_class) + elif plcy._orig_layer_class is not None: + policy.append(plcy._orig_layer_class) + for child in module_list: + if child.__class__ in policy: + return True + return False + def tp_parser(model): policy_list = [] module_list = [] layer_list = [] gem_list = [] - assert AutoTP.supported(model), "Automatic policy not supported for model. Please provide policy." - module_list = AutoTP.get_module_list(model) + assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \ + if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy." for module in module_list: for key, submodule in module._modules.items(): if isinstance(submodule, nn.Linear): @@ -103,5 +119,6 @@ class AutoTP(): gem_list = list(set(gem_list)) policy_list = AutoTP.update_policy_list(policy_list, module, gem_list) gem_list = [] - assert len(policy_list), "Not able to determine model policy automatically. Please provide policy." + assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \ + if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy." return policy_list -- GitLab