未验证 提交 430510bf 编写于 作者: S Satpal Singh Rathore 提交者: GitHub

Checks for user injection policy (#3052)

* check injection policy

* transformers v4

* move check_inference_tuple

* user injection policy check in infer engine

* fix pre-commit format

* fix formatting

* fix clang format

---------
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
上级 a23cda6c
......@@ -129,11 +129,21 @@ class InferenceEngine(Module):
# 1. User specified Tensor Parallelism
assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
for client_module, injection_policy in self.injection_dict.items():
assert issubclass(client_module,
torch.nn.Module), f"{client_module} is not a subclass of torch.nn.Module"
# construct the tuple and pass that instead of a string or dict.
if isinstance(injection_policy, str):
config.injection_policy_tuple = (injection_policy, )
else:
config.injection_policy_tuple = injection_policy
layer_names = [name for name, _ in self.module.named_modules()]
for policy in config.injection_policy_tuple:
if not any(name.endswith(policy) for name in layer_names):
raise ValueError(f"Injection policy layer'{policy}' not valid.")
self._apply_injection_policy(config, client_module)
else:
if config.replace_with_kernel_inject:
......
......@@ -10,7 +10,7 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal
""" Quantize bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
model (torch.nn.Module): user's nn.module representing their model
megatron (bool): megatron model-parallel implementation (this is supported for inference only)
......
......@@ -182,7 +182,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
model (torch.nn.Module): user's nn.module representing their model
checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
config: top-level DS Inference config defined in inference/config.py
......@@ -458,7 +458,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
""" Revert DeepSpeed's transformer layer back to original bert-style transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
e.g., transformers.modeling_bert.BertLayer.
e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
model (torch.nn.Module): user's nn.module representing their model
config (dict): model config containing hidden size, attention heads, etc.
Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册