Unverified commit e59f8054, authored by Michael Wyatt, committed by GitHub

Fix backward compatibility for InferenceConfig (#2516)

* Make new InferenceConfig backwards compatible with previous init_inference API
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 78d4ca1f
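For readers coming from the old keyword-style API, the snippet below is roughly the kind of call this commit keeps working. It is an illustrative sketch only (the stand-in model and the specific kwargs are placeholders, not part of the commit):

import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in module; a real HF/torch model is used in practice

# Legacy flat kwargs (mp_size, mpu, ep_size, moe, ...) predate DeepSpeedInferenceConfig.
# With this fix they are remapped onto the new nested config (e.g. mp_size ->
# tensor_parallel.tp_size) instead of breaking init_inference.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.float,
                                  replace_with_kernel_inject=False)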
@@ -51,11 +51,14 @@ jobs:
 if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
 git clone https://github.com/huggingface/accelerate
 cd accelerate
+# tmp fix
+git checkout 5f4ba04628eeea14f9d248ab0e54399899503532
+git rev-parse --short HEAD
 # installing dependencies
 pip install .[testing]
 # force protobuf version due to issues
 pip install "protobuf<4.21.0"
 # tmp fix: force newer datasets version
-pip install "datasets>=2.0.0"
+#pip install "datasets>=2.0.0"
 pip list
 HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose tests/deepspeed
@@ -54,7 +54,7 @@ jobs:
 git clone https://github.com/huggingface/transformers
 cd transformers
 # if needed switch to the last known good SHA until transformers@master is fixed
-git checkout 6268694e2
+#git checkout 6268694e2
 git rev-parse --short HEAD
 # scipy/sklearn required for tests, using the 'dev' extra forces torch re-install
 pip install .[testing]
......
 import torch
-from pydantic import validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
 from pydantic import Field
-from typing import Dict
+from pydantic import validator
+from typing import Dict, Union
 from enum import Enum
@@ -155,7 +155,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     Note that the masking is application specific.
     """

-    moe: DeepSpeedMoEConfig = {}
+    moe: Union[bool, DeepSpeedMoEConfig] = {}
     """ Specify if the type of Transformer is MoE. """

     quant: QuantizationConfig = {}
@@ -231,24 +231,27 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     to the required token-length required for your use-case.
     """

-    mp_size: int = Field(1,
-                         deprecated=True,
-                         new_param="tensor_parallel",
-                         set_new_param=False)
+    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel`` config to control model
     parallelism.
     """

-    @validator("mp_size")
-    def tp_size_set(cls, field_value, values):
-        print(values["tensor_parallel"].__fields_set__)
-        if "tp_size" in values["tensor_parallel"].__fields_set__:
-            assert (
-                values["tensor_parallel"].tp_size == field_value
-            ), f"Cannot provide different values for mp_size ({field_value}) and tensor_parallel.tp_size ({values['tensor_parallel'].tp_size})"
-        else:
-            values["tensor_parallel"].tp_size = field_value
+    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
+    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             deprecated=True,
+                             new_param="moe.ep_group")
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                deprecated=True,
+                                new_param="moe.ep_mp_group")
+    @validator("moe")
+    def moe_backward_compat(cls, field_value, values):
+        if isinstance(field_value, bool):
+            return DeepSpeedMoEConfig(moe=field_value)
+        return field_value

     class Config:
......
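Taken together, the new deprecated Fields and the moe validator mean an old-style flat config still constructs cleanly. A minimal sketch of the intended mapping (class and attribute names taken from the diff above; treat the exact runtime behavior as an assumption, it may differ between versions):

from deepspeed.inference.config import DeepSpeedInferenceConfig, DeepSpeedMoEConfig

# Old flat keys are remapped onto the nested sub-configs instead of being rejected.
cfg = DeepSpeedInferenceConfig(mp_size=4, moe=True, ep_size=2)

print(cfg.tensor_parallel.tp_size)              # 4, forwarded from the deprecated mp_size
print(isinstance(cfg.moe, DeepSpeedMoEConfig))  # True, the bare bool was promoted by moe_backward_compat
print(cfg.moe.ep_size)                          # 2, forwarded from the deprecated ep_size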
@@ -8,6 +8,7 @@ Collection of DeepSpeed configuration utilities
 import json
 import collections
 import collections.abc
+from functools import reduce
 from pydantic import BaseModel
 from deepspeed.utils import logger
@@ -59,32 +60,42 @@ class DeepSpeedConfigModel(BaseModel):
         fields_set = pydantic_config.__fields_set__
         dep_param = field.name
         kwargs = field.field_info.extra
+        new_param_fn = kwargs.get("new_param_fn", lambda x: x)
+        param_value = new_param_fn(getattr(pydantic_config, dep_param))
         new_param = kwargs.get("new_param", "")
         if dep_param in fields_set:
             logger.warning(f"Config parameter {dep_param} is deprecated" +
                            (f" use {new_param} instead" if new_param else ""))
             # Check if there is a new param and if it should be set with a value
             if new_param and kwargs.get("set_new_param", True):
-                # If the deprecate field was set and set_new_param is True, set new param value
+                # Remove the deprecate field if there is a replacing field
+                try:
+                    delattr(pydantic_config, dep_param)
+                except Exception as e:
+                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
+                    raise e
+
+                # Set new param value
+                new_param_nested = new_param.split(".")
+                if len(new_param_nested) > 1:
+                    # If the new param exists in a subconfig, we need to get
+                    # the fields set for that subconfig
+                    pydantic_config = reduce(getattr,
+                                             new_param_nested[:-1],
+                                             pydantic_config)
+                    fields_set = pydantic_config.__fields_set__
+                new_param_name = new_param_nested[-1]
                 assert (
-                    new_param not in fields_set
+                    new_param_name not in fields_set
                 ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
                 # A custom function for converting the old param value to new param value can be provided
-                new_param_fn = kwargs.get("new_param_fn", lambda x: x)
-                param_value = new_param_fn(getattr(pydantic_config, dep_param))
                 try:
-                    setattr(pydantic_config, new_param, param_value)
+                    setattr(pydantic_config, new_param_name, param_value)
                 except Exception as e:
                     logger.error(
                         f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'"
                     )
                     raise e
-            if new_param:
-                # Remember to remove the deprecate field if there is a replacing field
-                try:
-                    delattr(pydantic_config, dep_param)
-                except Exception as e:
-                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
-                    raise e

     def _deprecated_fields_check(self, pydantic_config):
         fields = pydantic_config.__fields__
......
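The heart of the config_utils change is resolving a dotted new_param such as "tensor_parallel.tp_size" down to the owning sub-config with functools.reduce before setting the value. A standalone toy sketch of that walk (plain classes standing in for the pydantic configs, not DeepSpeed code):

from functools import reduce

class TensorParallelStub:          # stand-in for the tensor_parallel sub-config
    def __init__(self):
        self.tp_size = 1

class ConfigStub:                  # stand-in for the top-level config
    def __init__(self):
        self.tensor_parallel = TensorParallelStub()

def set_nested(cfg, dotted_name, value):
    # Walk "a.b.c" down to the object that owns the last attribute, then set it.
    parts = dotted_name.split(".")
    owner = reduce(getattr, parts[:-1], cfg)   # e.g. cfg -> cfg.tensor_parallel
    setattr(owner, parts[-1], value)

cfg = ConfigStub()
set_nested(cfg, "tensor_parallel.tp_size", 8)  # what new_param="tensor_parallel.tp_size" triggers
print(cfg.tensor_parallel.tp_size)             # 8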