From e59f80549e30db0b6c088fd3bb2289ac98d91510 Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Thu, 17 Nov 2022 13:00:39 -0800
Subject: [PATCH] Fix backward compatibility for InferenceConfig (#2516)

* Make new InferenceConfig backwards compatible with previous init_inference API

Co-authored-by: Jeff Rasley
---
 .github/workflows/nv-accelerate-v100.yml   |  5 +++-
 .github/workflows/nv-transformers-v100.yml |  2 +-
 deepspeed/inference/config.py              | 35 ++++++++++++----------
 deepspeed/runtime/config_utils.py          | 35 ++++++++++++++--------
 4 files changed, 47 insertions(+), 30 deletions(-)

diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml
index ed836aa0..44d12e57 100644
--- a/.github/workflows/nv-accelerate-v100.yml
+++ b/.github/workflows/nv-accelerate-v100.yml
@@ -51,11 +51,14 @@ jobs:
           if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
           git clone https://github.com/huggingface/accelerate
           cd accelerate
+          # tmp fix
+          git checkout 5f4ba04628eeea14f9d248ab0e54399899503532
+          git rev-parse --short HEAD
           # installing dependencies
           pip install .[testing]
           # force protobuf version due to issues
           pip install "protobuf<4.21.0"
           # tmp fix: force newer datasets version
-          pip install "datasets>=2.0.0"
+          #pip install "datasets>=2.0.0"
           pip list
           HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose tests/deepspeed
diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml
index 55a887c8..9ed60a71 100644
--- a/.github/workflows/nv-transformers-v100.yml
+++ b/.github/workflows/nv-transformers-v100.yml
@@ -54,7 +54,7 @@ jobs:
           git clone https://github.com/huggingface/transformers
           cd transformers
           # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout 6268694e2
+          #git checkout 6268694e2
           git rev-parse --short HEAD
           # scipy/sklearn required for tests, using the 'dev' extra forces torch re-install
           pip install .[testing]
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 42facc9d..e64e585a 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -1,9 +1,9 @@
 import torch
-from pydantic import validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
 from pydantic import Field
-from typing import Dict
+from pydantic import validator
+from typing import Dict, Union
 from enum import Enum
 
 
@@ -155,7 +155,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     Note that the masking is application specific.
     """
 
-    moe: DeepSpeedMoEConfig = {}
+    moe: Union[bool, DeepSpeedMoEConfig] = {}
     """ Specify if the type of Transformer is MoE. """
 
     quant: QuantizationConfig = {}
@@ -231,24 +231,27 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     to the required token-length required for your use-case.
     """
 
-    mp_size: int = Field(1,
-                         deprecated=True,
-                         new_param="tensor_parallel",
-                         set_new_param=False)
+    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel` config to control model
     parallelism.
     """
-    @validator("mp_size")
-    def tp_size_set(cls, field_value, values):
-        print(values["tensor_parallel"].__fields_set__)
-        if "tp_size" in values["tensor_parallel"].__fields_set__:
-            assert (
-                values["tensor_parallel"].tp_size == field_value
-            ), f"Cannot provide different values for mp_size ({field_value}) and tensor_parallel.tp_size ({values['tensor_parallel'].tp_size})"
-        else:
-            values["tensor_parallel"].tp_size = field_value
+    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
+    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             deprecated=True,
+                             new_param="moe.ep_group")
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                deprecated=True,
+                                new_param="moe.ep_mp_group")
+
+    @validator("moe")
+    def moe_backward_compat(cls, field_value, values):
+        if isinstance(field_value, bool):
+            return DeepSpeedMoEConfig(moe=field_value)
         return field_value
 
     class Config:
diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py
index aec6faec..81ef972a 100755
--- a/deepspeed/runtime/config_utils.py
+++ b/deepspeed/runtime/config_utils.py
@@ -8,6 +8,7 @@ Collection of DeepSpeed configuration utilities
 import json
 import collections
 import collections.abc
+from functools import reduce
 from pydantic import BaseModel
 from deepspeed.utils import logger
 
@@ -59,32 +60,42 @@ class DeepSpeedConfigModel(BaseModel):
         fields_set = pydantic_config.__fields_set__
         dep_param = field.name
         kwargs = field.field_info.extra
+        new_param_fn = kwargs.get("new_param_fn", lambda x: x)
+        param_value = new_param_fn(getattr(pydantic_config, dep_param))
         new_param = kwargs.get("new_param", "")
         if dep_param in fields_set:
             logger.warning(f"Config parameter {dep_param} is deprecated" +
                            (f" use {new_param} instead" if new_param else ""))
+            # Check if there is a new param and if it should be set with a value
             if new_param and kwargs.get("set_new_param", True):
-                # If the deprecate field was set and set_new_param is True, set new param value
+                # Remove the deprecate field if there is a replacing field
+                try:
+                    delattr(pydantic_config, dep_param)
+                except Exception as e:
+                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
+                    raise e
+
+                # Set new param value
+                new_param_nested = new_param.split(".")
+                if len(new_param_nested) > 1:
+                    # If the new param exists in a subconfig, we need to get
+                    # the fields set for that subconfig
+                    pydantic_config = reduce(getattr,
+                                             new_param_nested[:-1],
+                                             pydantic_config)
+                    fields_set = pydantic_config.__fields_set__
+                new_param_name = new_param_nested[-1]
                 assert (
-                    new_param not in fields_set
+                    new_param_name not in fields_set
                 ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
                 # A custom function for converting the old param value to new param value can be provided
-                new_param_fn = kwargs.get("new_param_fn", lambda x: x)
-                param_value = new_param_fn(getattr(pydantic_config, dep_param))
                 try:
-                    setattr(pydantic_config, new_param, param_value)
+                    setattr(pydantic_config, new_param_name, param_value)
                 except Exception as e:
                     logger.error(
                         f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'"
                     )
                     raise e
-            if new_param:
-                # Remember to remove the deprecate field if there is a replacing field
-                try:
-                    delattr(pydantic_config, dep_param)
-                except Exception as e:
-                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
-                    raise e
 
     def _deprecated_fields_check(self, pydantic_config):
         fields = pydantic_config.__fields__
-- 
GitLab
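
Note (editor's addition, not part of the patch): the nested new_param handling added to config_utils.py above amounts to walking a dotted attribute path such as "tensor_parallel.tp_size" with functools.reduce and setting the leaf attribute on the resolved subconfig. Below is a minimal, self-contained sketch of that idea; the class and function names are invented for illustration and are not DeepSpeed APIs.

    from functools import reduce

    # Stand-in config objects; in DeepSpeed these are pydantic models.
    class TensorParallelConfig:
        def __init__(self):
            self.tp_size = 1

    class InferenceConfig:
        def __init__(self):
            self.tensor_parallel = TensorParallelConfig()

    def forward_deprecated(config, new_param, value):
        # Mirror the patch: split the dotted path, walk down to the owning
        # subconfig with reduce(getattr, ...), then set the final attribute.
        parts = new_param.split(".")
        target = reduce(getattr, parts[:-1], config)  # config itself if the path has one part
        setattr(target, parts[-1], value)

    cfg = InferenceConfig()
    forward_deprecated(cfg, "tensor_parallel.tp_size", 2)  # e.g. a user passed the deprecated mp_size=2
    assert cfg.tensor_parallel.tp_size == 2

The user-visible effect of the patch is that older-style init_inference arguments such as mp_size should keep working: the deprecated keyword is dropped from the config, a deprecation warning is logged, and the value is forwarded to the replacing nested field (tensor_parallel.tp_size here). Likewise ep_size, ep_group, and ep_mp_group are forwarded into the moe subconfig, and a bare moe=True is promoted to a DeepSpeedMoEConfig by the new validator.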