Load DeepSpeed configs with hjson instead of json.

DeepSpeedConfig now parses both config files and base64-encoded config
strings with hjson. The unit tests gain TestConfigLoad, which calls
deepspeed.initialize with the config supplied as a dict, as a file written by
json.dump, and as a file written by hjson.dump.

NOTE(review): line structure and indentation were reconstructed from the hunk
headers; the post-image blob ids on the "index" lines predate the review
changes (context-managed open(), added comment and docstrings).

diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index dde8cd952e52bc6aa88fd297fd69905492840bdb..3e5d2cfff81c4ef1142a81789f33e965af08680b 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -7,6 +7,7 @@ from typing import Union
 
 import torch
 import json
+import hjson
 import copy
 import base64
 
@@ -705,14 +706,15 @@ class DeepSpeedConfig(object):
         if isinstance(config, dict):
             self._param_dict = config
         elif os.path.exists(config):
-            self._param_dict = json.load(
-                open(config,
-                     "r"),
-                object_pairs_hook=dict_raise_error_on_duplicate_keys)
+            # hjson accepts plain JSON too, so existing .json configs keep working.
+            with open(config, "r") as config_file:
+                self._param_dict = hjson.load(
+                    config_file,
+                    object_pairs_hook=dict_raise_error_on_duplicate_keys)
         else:
             try:
                 config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
-                self._param_dict = json.loads(config_decoded)
+                self._param_dict = hjson.loads(config_decoded)
             except (UnicodeDecodeError, AttributeError):
                 raise ValueError(
                     f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py
index 92ea2ce22ac0d1886066330ac03de35a52cadaa1..8371dc7c7a2f1e6383a88a9b9c83e4f3300fd012 100644
--- a/tests/unit/runtime/test_ds_config_dict.py
+++ b/tests/unit/runtime/test_ds_config_dict.py
@@ -1,7 +1,9 @@
 # A test on its own
+import os
 import torch
 import pytest
 import json
+import hjson
 import argparse
 
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
@@ -158,11 +160,45 @@ def test_get_bfloat16_enabled(bf16_key):
     assert get_bfloat16_enabled(cfg) == True
 
 
+class TestConfigLoad(DistributedTest):
+    """Smoke-test deepspeed.initialize with a dict, a JSON file and an Hjson file as config."""
+    world_size = 1
+
+    def test_dict(self, base_config):
+        """Pass the config as an in-memory dict."""
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        model, _, _, _ = deepspeed.initialize(config=base_config,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+    def test_json(self, base_config, tmpdir):
+        """Pass the path of a config file written with json.dump."""
+        config_path = os.path.join(tmpdir, "config.json")
+        with open(config_path, 'w') as fp:
+            json.dump(base_config, fp)
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        model, _, _, _ = deepspeed.initialize(config=config_path,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+    def test_hjson(self, base_config, tmpdir):
+        """Pass the path of a config file written with hjson.dump."""
+        config_path = os.path.join(tmpdir, "config.json")
+        with open(config_path, 'w') as fp:
+            hjson.dump(base_config, fp)
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        model, _, _, _ = deepspeed.initialize(config=config_path,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+
 class TestDeprecatedDeepScaleConfig(DistributedTest):
     world_size = 1
 
     def test(self, base_config, tmpdir):
-
         config_path = create_config_from_dict(tmpdir, base_config)
         parser = argparse.ArgumentParser()
         args = parser.parse_args(args='')