From c87118b0c5c4cfc47446ccd37c22fdc063143fa4 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 14 Apr 2021 07:46:46 -0700 Subject: [PATCH] [config] turn exponential notation back on for config dump (#955) * e-notation for large floats * handle ints too * readability * handle bool Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/config.py | 3 ++- deepspeed/runtime/config_utils.py | 42 +++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 11e1d403..b9ba6cf3 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -9,7 +9,7 @@ import copy from .constants import * from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE -from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys +from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder from .zero.config import DeepSpeedZeroConfig from .zero.constants import * from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig @@ -744,6 +744,7 @@ class DeepSpeedConfig(object): json.dumps(self._param_dict, sort_keys=True, indent=4, + cls=ScientificNotationEncoder, separators=(',', ':')))) diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 62782852..12711d56 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -6,7 +6,40 @@ Licensed under the MIT license. Collection of DeepSpeed configuration utilities """ import json -from collections import Counter +from collections import Counter, Mapping, Sequence + + +# adapted from https://stackoverflow.com/a/50701137/9201239 +class ScientificNotationEncoder(json.JSONEncoder): + """ + This class overrides ``json.dumps`` default formatter. + + This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation. + + Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it + + """ + def iterencode(self, o, _one_shot=False, level=0): + indent = self.indent if self.indent is not None else 4 + prefix_close = " " * level * indent + level += 1 + prefix = " " * level * indent + if isinstance(o, bool): + return "true" if o else "false" + elif isinstance(o, float) or isinstance(o, int): + if o > 1e3: + return f"{o:e}" + else: + return f"{o}" + elif isinstance(o, Mapping): + x = [ + f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, + v in o.items() + ] + return "{" + ', '.join(x) + f"\n{prefix_close}" + "}" + elif isinstance(o, Sequence) and not isinstance(o, str): + return f"[{ f', '.join(map(self.iterencode, o)) }]" + return "\n, ".join(super().iterencode(o, _one_shot)) class DeepSpeedConfigObject(object): @@ -17,7 +50,12 @@ class DeepSpeedConfigObject(object): return self.__dict__ def __repr__(self): - return json.dumps(self.__dict__, sort_keys=True, indent=4) + return json.dumps( + self.__dict__, + sort_keys=True, + indent=4, + cls=ScientificNotationEncoder, + ) def get_scalar_param(param_dict, param_name, param_default_value): -- GitLab