未验证 提交 c87118b0 编写于 作者: S Stas Bekman 提交者: GitHub

[config] turn exponential notation back on for config dump (#955)

* e-notation for large floats

* handle ints too

* readability

* handle bool
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 adac058a
...@@ -9,7 +9,7 @@ import copy ...@@ -9,7 +9,7 @@ import copy
from .constants import * from .constants import *
from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder
from .zero.config import DeepSpeedZeroConfig from .zero.config import DeepSpeedZeroConfig
from .zero.constants import * from .zero.constants import *
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
...@@ -744,6 +744,7 @@ class DeepSpeedConfig(object): ...@@ -744,6 +744,7 @@ class DeepSpeedConfig(object):
json.dumps(self._param_dict, json.dumps(self._param_dict,
sort_keys=True, sort_keys=True,
indent=4, indent=4,
cls=ScientificNotationEncoder,
separators=(',', separators=(',',
':')))) ':'))))
......
...@@ -6,7 +6,40 @@ Licensed under the MIT license. ...@@ -6,7 +6,40 @@ Licensed under the MIT license.
Collection of DeepSpeed configuration utilities Collection of DeepSpeed configuration utilities
""" """
import json import json
from collections import Counter from collections import Counter, Mapping, Sequence
# adapted from https://stackoverflow.com/a/50701137/9201239
class ScientificNotationEncoder(json.JSONEncoder):
"""
This class overrides ``json.dumps`` default formatter.
This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation.
Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it
"""
def iterencode(self, o, _one_shot=False, level=0):
indent = self.indent if self.indent is not None else 4
prefix_close = " " * level * indent
level += 1
prefix = " " * level * indent
if isinstance(o, bool):
return "true" if o else "false"
elif isinstance(o, float) or isinstance(o, int):
if o > 1e3:
return f"{o:e}"
else:
return f"{o}"
elif isinstance(o, Mapping):
x = [
f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k,
v in o.items()
]
return "{" + ', '.join(x) + f"\n{prefix_close}" + "}"
elif isinstance(o, Sequence) and not isinstance(o, str):
return f"[{ f', '.join(map(self.iterencode, o)) }]"
return "\n, ".join(super().iterencode(o, _one_shot))
class DeepSpeedConfigObject(object): class DeepSpeedConfigObject(object):
...@@ -17,7 +50,12 @@ class DeepSpeedConfigObject(object): ...@@ -17,7 +50,12 @@ class DeepSpeedConfigObject(object):
return self.__dict__ return self.__dict__
def __repr__(self): def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4) return json.dumps(
self.__dict__,
sort_keys=True,
indent=4,
cls=ScientificNotationEncoder,
)
def get_scalar_param(param_dict, param_name, param_default_value): def get_scalar_param(param_dict, param_name, param_default_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册