未验证 提交 c87118b0 编写于 作者: S Stas Bekman 提交者: GitHub

[config] turn exponential notation back on for config dump (#955)

* e-notation for large floats

* handle ints too

* readability

* handle bool
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 adac058a
......@@ -9,7 +9,7 @@ import copy
from .constants import *
from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder
from .zero.config import DeepSpeedZeroConfig
from .zero.constants import *
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
......@@ -744,6 +744,7 @@ class DeepSpeedConfig(object):
json.dumps(self._param_dict,
sort_keys=True,
indent=4,
cls=ScientificNotationEncoder,
separators=(',',
':'))))
......
......@@ -6,7 +6,40 @@ Licensed under the MIT license.
Collection of DeepSpeed configuration utilities
"""
import json
from collections import Counter
from collections import Counter, Mapping, Sequence
# adapted from https://stackoverflow.com/a/50701137/9201239
class ScientificNotationEncoder(json.JSONEncoder):
"""
This class overrides ``json.dumps`` default formatter.
This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation.
Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it
"""
def iterencode(self, o, _one_shot=False, level=0):
indent = self.indent if self.indent is not None else 4
prefix_close = " " * level * indent
level += 1
prefix = " " * level * indent
if isinstance(o, bool):
return "true" if o else "false"
elif isinstance(o, float) or isinstance(o, int):
if o > 1e3:
return f"{o:e}"
else:
return f"{o}"
elif isinstance(o, Mapping):
x = [
f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k,
v in o.items()
]
return "{" + ', '.join(x) + f"\n{prefix_close}" + "}"
elif isinstance(o, Sequence) and not isinstance(o, str):
return f"[{ f', '.join(map(self.iterencode, o)) }]"
return "\n, ".join(super().iterencode(o, _one_shot))
class DeepSpeedConfigObject(object):
......@@ -17,7 +50,12 @@ class DeepSpeedConfigObject(object):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
return json.dumps(
self.__dict__,
sort_keys=True,
indent=4,
cls=ScientificNotationEncoder,
)
def get_scalar_param(param_dict, param_name, param_default_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册