未验证 提交 4079077c 编写于 作者: M Michael Wyatt 提交者: GitHub

add support for hjson config files (#2783)

Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 2c6e8194
......@@ -7,6 +7,7 @@ from typing import Union
import torch
import json
import hjson
import copy
import base64
......@@ -705,14 +706,14 @@ class DeepSpeedConfig(object):
if isinstance(config, dict):
self._param_dict = config
elif os.path.exists(config):
self._param_dict = json.load(
self._param_dict = hjson.load(
open(config,
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
try:
config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
self._param_dict = json.loads(config_decoded)
self._param_dict = hjson.loads(config_decoded)
except (UnicodeDecodeError, AttributeError):
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
......
# A test on its own
import os
import torch
import pytest
import json
import hjson
import argparse
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
......@@ -158,11 +160,41 @@ def test_get_bfloat16_enabled(bf16_key):
assert get_bfloat16_enabled(cfg) == True
class TestConfigLoad(DistributedTest):
world_size = 1
def test_dict(self, base_config):
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=base_config,
model=model,
model_parameters=model.parameters())
def test_json(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
json.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())
def test_hjson(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
hjson.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())
class TestDeprecatedDeepScaleConfig(DistributedTest):
world_size = 1
def test(self, base_config, tmpdir):
config_path = create_config_from_dict(tmpdir, base_config)
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册