未验证 提交 4079077c 编写于 作者: M Michael Wyatt 提交者: GitHub

add support for hjson config files (#2783)

Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 2c6e8194
...@@ -7,6 +7,7 @@ from typing import Union ...@@ -7,6 +7,7 @@ from typing import Union
import torch import torch
import json import json
import hjson
import copy import copy
import base64 import base64
...@@ -705,14 +706,14 @@ class DeepSpeedConfig(object): ...@@ -705,14 +706,14 @@ class DeepSpeedConfig(object):
if isinstance(config, dict): if isinstance(config, dict):
self._param_dict = config self._param_dict = config
elif os.path.exists(config): elif os.path.exists(config):
self._param_dict = json.load( self._param_dict = hjson.load(
open(config, open(config,
"r"), "r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys) object_pairs_hook=dict_raise_error_on_duplicate_keys)
else: else:
try: try:
config_decoded = base64.urlsafe_b64decode(config).decode('utf-8') config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
self._param_dict = json.loads(config_decoded) self._param_dict = hjson.loads(config_decoded)
except (UnicodeDecodeError, AttributeError): except (UnicodeDecodeError, AttributeError):
raise ValueError( raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}" f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
......
# A test on its own # A test on its own
import os
import torch import torch
import pytest import pytest
import json import json
import hjson
import argparse import argparse
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
...@@ -158,11 +160,41 @@ def test_get_bfloat16_enabled(bf16_key): ...@@ -158,11 +160,41 @@ def test_get_bfloat16_enabled(bf16_key):
assert get_bfloat16_enabled(cfg) == True assert get_bfloat16_enabled(cfg) == True
class TestConfigLoad(DistributedTest):
world_size = 1
def test_dict(self, base_config):
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=base_config,
model=model,
model_parameters=model.parameters())
def test_json(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
json.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())
def test_hjson(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
hjson.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())
class TestDeprecatedDeepScaleConfig(DistributedTest): class TestDeprecatedDeepScaleConfig(DistributedTest):
world_size = 1 world_size = 1
def test(self, base_config, tmpdir): def test(self, base_config, tmpdir):
config_path = create_config_from_dict(tmpdir, base_config) config_path = create_config_from_dict(tmpdir, base_config)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
args = parser.parse_args(args='') args = parser.parse_args(args='')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册