Unverified commit b5d18a6a authored by Ammar Ahmad Awan, committed by GitHub

DeepSpeed inference config. (#2459) (#2472)

Changes the inference API to accept a config dict and cleans up the Inference Engine to utilize the newly added inference config.
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent a4ceabb6
......@@ -17,7 +17,7 @@ from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpe
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
from .inference.config import DeepSpeedInferenceConfig
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
......@@ -222,79 +222,46 @@ def add_config_arguments(parser):
return parser
def init_inference(model,
triangular_masking=True,
mp_size=1,
training_mp_size=1,
mpu=None,
ep_group=None,
expert_mp_group=None,
checkpoint=None,
dtype=None,
injection_policy=None,
replace_method='auto',
quantization_setting=None,
replace_with_kernel_inject=False,
return_tuple=True,
ep_size=1,
moe=False,
moe_experts=1,
moe_type='standard',
args=None,
enable_cuda_graph=False,
save_mp_checkpoint_path=None,
base_dir="",
max_tokens=1024):
def default_inference_config():
"""
Return a default DeepSpeed inference configuration dictionary.
"""
return DeepSpeedInferenceConfig().dict()
def init_inference(model, config=None, **kwargs):
"""Initialize the DeepSpeed InferenceEngine.
Arguments:
model: Required: nn.module class before applying any wrappers
Description: all four cases are valid and supported in the DS init_inference() API.
triangular_masking: Required: the type of masking for attention scores in the transformer layer;
note that the masking is application specific.
# Case 1: user provides no config and no kwargs. Default config will be used.
generator.model = deepspeed.init_inference(generator.model)
string = generator("DeepSpeed is")
print(string)
mp_size: Optional: Desired model parallel size, default is 1 meaning no
model parallelism.
# Case 2: user provides a config and no kwargs. User supplied config will be used.
generator.model = deepspeed.init_inference(generator.model, config=config)
string = generator("DeepSpeed is")
print(string)
training_mp_size: Optional: if loading a checkpoint, this is the mp size it was trained with;
it may differ from the mp size you want to use during inference.
# Case 3: user provides no config and uses keyword arguments (kwargs) only.
generator.model = deepspeed.init_inference(generator.model,
mp_size=world_size,
dtype=torch.half,
replace_with_kernel_inject=True)
string = generator("DeepSpeed is")
print(string)
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
# Case 4: user provides config and keyword arguments (kwargs). Both config and kwargs are merged and kwargs take precedence.
generator.model = deepspeed.init_inference(generator.model, config={"dtype": torch.half}, replace_with_kernel_inject=True)
string = generator("DeepSpeed is")
print(string)
Arguments:
model: Required: original nn.module object without any wrappers
config: Optional: instead of keyword arguments, you can pass in a DS inference config dict
checkpoint: Optional: Path to deepspeed compatible checkpoint or path to
JSON with load policy.
dtype: Optional: Desired model data type, will convert model to this type.
Supported target types: torch.half, torch.int8, torch.float
injection_policy: Optional: Dictionary mapping a client nn.Module to its corresponding
injection policy. e.g., {BertLayer : deepspeed.inference.HFBertLayerPolicy}
replace_method: Optional: If 'auto', DeepSpeed will automatically try to replace
model modules with its optimized versions. If an injection_policy is set this will
override the automatic replacement behavior.
quantization_setting: Optional: Quantization settings used for quantizing your model using MoQ.
The setting can be one element or a tuple. If one value is passed in, it is treated as the number
of groups used in quantization. Pass a tuple to indicate extra grouping for the MLP part of a
Transformer layer (e.g., (True, 8) quantizes the model using 8 groups for the whole network
except the MLP part, which uses 8 extra groups).
replace_with_kernel_inject: this flag needs to be set to true to inject inference kernels for models such as Bert, GPT2, GPT-Neo and GPT-J. Otherwise,
the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
return_tuple: Specify whether or not the transformer layers need to return a tuple or a Tensor. It is set to True by default (returning a tuple).
ep_size: The expert-parallelism size which is used for partitioning the experts across the GPUs in the expert-parallel group.
moe: Specify whether the Transformer is of MoE type. It is set to False by default.
moe_experts: The global number of experts used in an MoE layer.
moe_type: Specify the type of MoE layer. We have two types of MoE layer: 'Standard' and 'Residual'. It is set to 'Standard' type by default.
args: All the arguments used for launching the inference API that may be useful to the inference engine for injecting the optimizations.
enable_cuda_graph: use this flag to capture the CUDA graph of the inference ops, so that they can run faster using the graph replay method;
this is set to False by default
save_mp_checkpoint_path: The path at which to save the loaded model as a checkpoint. This feature is used for adjusting the
parallelism degree to help alleviate the model loading overhead. It does not save any new checkpoint if no path is passed.
base_dir: The root directory under which all the checkpoint files exist. This can also be passed through the JSON config.
max_tokens: The maximum number of tokens the inference engine can work with, including the input and output tokens.
Please consider increasing it to the token length required for your use case.
Returns:
A deepspeed.InferenceEngine wrapped model.
"""
......@@ -304,28 +271,20 @@ def init_inference(model,
__git_branch__),
ranks=[0])
engine = InferenceEngine(model,
triangular_masking,
mp_size,
training_mp_size,
ep_size,
mpu,
ep_group,
expert_mp_group,
checkpoint,
dtype,
injection_policy,
return_tuple,
replace_method,
quantization_setting,
replace_with_kernel_inject,
moe,
moe_experts,
moe_type,
args,
enable_cuda_graph,
save_mp_checkpoint_path,
base_dir,
max_tokens)
# User did not pass a config, use defaults
if config is None:
config_dict = kwargs
else:
config_dict = config
# If both config and kwargs are passed, merge them, with kwargs taking precedence
if config and kwargs:
config_dict = {}
config_dict.update(config)
config_dict.update(kwargs)
ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
engine = InferenceEngine(model, config=ds_inference_config)
return engine
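For context (a usage sketch, not part of this diff), the reworked API can be exercised as below; the toy Linear module is a stand-in for a real transformer model, and DeepSpeed with a suitable accelerator is assumed to be installed.

import torch
import deepspeed

# Stand-in module; in practice this would be e.g. a Hugging Face transformer model.
model = torch.nn.Linear(8, 8)

# Start from the defaults exposed by the new default_inference_config() helper
# and override a field in the returned dict.
config = deepspeed.default_inference_config()
config["dtype"] = torch.half

# Keyword arguments are merged on top of the config dict, so the explicit
# replace_with_kernel_inject=True below wins over the default value in config.
engine = deepspeed.init_inference(model,
                                  config=config,
                                  replace_with_kernel_inject=True)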
import torch
from pydantic import validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
from typing import Dict
from enum import Enum
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8"
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj
def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)
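A quick sketch of what the multi-value pattern above provides (illustrative asserts, not part of the diff): every listed alias resolves to the same member, and the member's canonical value is the torch dtype.

assert DtypeEnum("fp16") is DtypeEnum.fp16
assert DtypeEnum("half") is DtypeEnum.fp16
assert DtypeEnum(torch.float16) is DtypeEnum.fp16
assert DtypeEnum.fp16.value == torch.float16  # canonical value is the torch dtype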
class MoETypeEnum(str, Enum):
residual = "residual"
standard = "standard"
class DeepSpeedTPConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """
enabled: bool = True
""" Turn tensor parallelism on/off. """
tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """
mpu: object = None
"""
A model parallelism unit object that implements
``get_{model,data}_parallel_{rank,group,world_size}()``.
"""
tp_group: object = None
class DeepSpeedMoEConfig(DeepSpeedConfigModel):
""" Sets parameters for MoE """
enabled: bool = True
ep_size: int = 1
"""
The expert-parallelism size which is used for partitioning the experts
across the GPUs in the expert-parallel group.
"""
moe_experts: list = Field([1], alias="num_experts")
""" The global number of experts used in an MoE layer. """
moe_type: MoETypeEnum = MoETypeEnum.standard
"""
Specify the type of MoE layer. We have two types of MoE layer: 'Standard'
and 'Residual'.
"""
ep_mp_group: object = None
ep_group: object = Field(None, alias="expert_group")
class QuantTypeEnum(str, Enum):
asym = "asymmetric"
sym = "symmetric"
class BaseQuantConfig(DeepSpeedConfigModel):
enabled = True
num_bits = 8
q_type: QuantTypeEnum = QuantTypeEnum.sym
q_groups: int = 1
class WeightQuantConfig(BaseQuantConfig):
enabled = True
class ActivationQuantConfig(BaseQuantConfig):
enabled = True
class QKVQuantConfig(DeepSpeedConfigModel):
enabled = True
class QuantizationConfig(DeepSpeedConfigModel):
enabled: bool = True
activation: ActivationQuantConfig = ActivationQuantConfig()
weight: WeightQuantConfig = WeightQuantConfig()
qkv: QKVQuantConfig = QKVQuantConfig()
# todo: brainstorm on how to do ckpt loading for DS inference
class InferenceCheckpointConfig(DeepSpeedConfigModel):
checkpoint_dir: str = None
save_mp_checkpoint_path: str = None
base_dir: str = None
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
""" Sets parameters for DeepSpeed Inference Engine. """
replace_with_kernel_inject: bool = Field(False, alias="kernel_inject")
"""
Set to true to inject inference kernels for models such as Bert, GPT2,
GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two
linear layers as a tuple: (attention_output projection, transformer output
projection)
"""
dtype: DtypeEnum = torch.float16
"""
Desired model data type, will convert model to this type.
Supported target types: torch.half, torch.int8, torch.float
"""
tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several GPUs.
"""
enable_cuda_graph: bool = False
"""
Use this flag to capture the CUDA graph of the inference ops, so that they
can run faster using the graph replay method.
"""
zero: DeepSpeedZeroConfig = {}
""" ZeRO configuration to use with the Inference Engine. """
triangular_masking: bool = Field(True, alias="tm")
"""
Controls the type of masking for attention scores in the transformer layer.
Note that the masking is application specific.
"""
moe: DeepSpeedMoEConfig = {}
""" Specify if the type of Transformer is MoE. """
quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
Quantization settings used for quantizing your model using MoQ. The
setting can be one element or a tuple. If one value is passed in, it is
treated as the number of groups used in quantization. Pass a tuple to
indicate extra grouping for the MLP part of a Transformer layer (e.g.,
(True, 8) quantizes the model using 8 groups for the whole network except
the MLP part, which uses 8 extra groups).
"""
#todo: refactor the following 3 into the new checkpoint_config
checkpoint: str = None
"""
Path to deepspeed compatible checkpoint or path to JSON with load policy.
"""
base_dir: str = None
"""
The root directory under which all the checkpoint files exist.
This can also be passed through the JSON config.
"""
save_mp_checkpoint_path: str = None
"""
The path at which to save the loaded model as a checkpoint. This
feature is used for adjusting the parallelism degree to help alleviate the
model loading overhead. It does not save any new checkpoint if no path is
passed.
"""
checkpoint_config: InferenceCheckpointConfig = Field({}, alias="ckpt_config")
""" TODO: Add docs """
return_tuple: bool = True
"""
Specify whether or not the transformer layers need to return a tuple or a
Tensor.
"""
training_mp_size: int = 1
"""
If loading a checkpoint, this is the mp size it was trained with; it may
differ from the mp size you want to use during inference.
"""
replace_method: str = "auto"
"""
If 'auto', DeepSpeed will automatically try to replace model modules with
its optimized versions. If an injection_policy is set this will override
the automatic replacement behavior.
"""
injection_policy: Dict = Field(None, alias="injection_dict")
"""
Dictionary mapping a client nn.Module to its corresponding injection
policy. e.g., {BertLayer : deepspeed.inference.HFBertLayerPolicy}
"""
injection_policy_tuple: tuple = None
""" TODO: Add docs """
config: Dict = None # todo: really no need for this field if we can refactor
max_out_tokens: int = 1024
"""
The maximum number of tokens the inference engine can work with, including
the input and output tokens. Please consider increasing it to the token
length required for your use case.
"""
mp_size: int = Field(1,
deprecated=True,
new_param="tensor_parallel",
set_new_param=False)
"""
Desired model parallel size, default is 1 meaning no model parallelism.
Deprecated, please use the ``tensor_parallel`` config to control model
parallelism.
"""
@validator("mp_size")
def tp_size_set(cls, field_value, values):
print(values["tensor_parallel"].__fields_set__)
if "tp_size" in values["tensor_parallel"].__fields_set__:
assert (
values["tensor_parallel"].tp_size == field_value
), f"Cannot provide different values for mp_size ({field_value}) and tensor_parallel.tp_size ({values['tensor_parallel'].tp_size})"
else:
values["tensor_parallel"].tp_size = field_value
return field_value
class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
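To illustrate the config above (a sketch, not part of the diff; it assumes the DeepSpeedConfigModel base allows population by alias and attribute mutation, as in the DeepSpeed source): fields can be populated via their aliases, nested configs accept plain dicts, and the deprecated mp_size path is forwarded into tensor_parallel.tp_size by the validator.

cfg = DeepSpeedInferenceConfig(dtype=torch.float16,
                               kernel_inject=True,   # alias of replace_with_kernel_inject
                               tp={"tp_size": 2})    # alias of tensor_parallel
assert cfg.replace_with_kernel_inject is True
assert cfg.tensor_parallel.tp_size == 2

# Deprecated path: mp_size is accepted with a warning and copied into
# tensor_parallel.tp_size when tp_size was not set explicitly.
legacy = DeepSpeedInferenceConfig(mp_size=4)
assert legacy.tensor_parallel.tp_size == 4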
......@@ -423,6 +423,7 @@ def replace_transformer_layer(orig_layer_impl,
if moe:
ep_world_size = dist.get_world_size()
local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size
bigscience_bloom = policy_cls is BLOOMLayerPolicy
transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
hidden_size=hidden_size,
......@@ -796,9 +797,7 @@ def replace_transformer_layer(orig_layer_impl,
weight_shape = child.weight.ds_shape
else:
weight_shape = child.weight.shape
if (isinstance(all_reduce_linears,
tuple) or isinstance(all_reduce_linears,
str)) and name in all_reduce_linears:
if name in all_reduce_linears:
new_weight = torch.empty((
weight_shape[1] if conv_linear_layer else weight_shape[0],
(weight_shape[0] if conv_linear_layer else weight_shape[1]) //
......
......@@ -55,17 +55,20 @@ class DeepSpeedConfigModel(BaseModel):
self._deprecated_fields_check(self)
def _process_deprecated_field(self, pydantic_config, field):
# Get information about the deprecated field
fields_set = pydantic_config.__fields_set__
dep_param = field.name
kwargs = field.field_info.extra
new_param = kwargs.get("new_param", "")
if dep_param in fields_set:
kwargs = field.field_info.extra
new_param = kwargs.get("new_param", "")
logger.warning(f"Config parameter {dep_param} is deprecated" +
(f" use {new_param} instead" if new_param else ""))
if new_param and kwargs.get("set_new_param", True):
# If the deprecated field was set and set_new_param is True, set the new param value
assert (
new_param not in fields_set
), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
# A custom function for converting the old param value to new param value can be provided
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
param_value = new_param_fn(getattr(pydantic_config, dep_param))
try:
......@@ -75,6 +78,13 @@ class DeepSpeedConfigModel(BaseModel):
f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'"
)
raise e
if new_param:
# Remember to remove the deprecated field if there is a replacing field
try:
delattr(pydantic_config, dep_param)
except Exception as e:
logger.error(f"Tried removing deprecated '{dep_param}' from config")
raise e
def _deprecated_fields_check(self, pydantic_config):
fields = pydantic_config.__fields__
......
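To make the deprecation flow above concrete, a hypothetical model (the names ExampleConfig, old_size and new_size are illustrative, not from this PR) that declares a deprecated field migrating to a replacement:

from pydantic import Field
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

class ExampleConfig(DeepSpeedConfigModel):
    new_size: int = 1
    old_size: int = Field(1, deprecated=True, new_param="new_size")

cfg = ExampleConfig(old_size=8)
# A deprecation warning is logged, the value is forwarded to new_size, and the
# old_size attribute is removed from the instance afterwards.
assert cfg.new_size == 8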