提交 8b9b9cf9 编写于 作者: T tangwei

merge yaml two to one

上级 59f4df30
......@@ -89,7 +89,7 @@ def user_define_engine(engine_yaml):
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
envs.set_runtime_envions(_config)
envs.set_runtime_environs(_config)
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
......
......@@ -18,13 +18,14 @@ import copy
global_envs = {}
def set_runtime_envions(envs):
def flatten_environs(envs):
flatten_dict = {}
assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs):
if not isinstance(local_envs, dict):
global_k = ".".join(namespace_nests)
os.environ[global_k] = str(local_envs)
flatten_dict[global_k] = str(local_envs)
else:
for k, v in local_envs.items():
if isinstance(v, dict):
......@@ -33,18 +34,24 @@ def set_runtime_envions(envs):
fatten_env_namespace(nests, v)
else:
global_k = ".".join(namespace_nests + [k])
os.environ[global_k] = str(v)
flatten_dict[global_k] = str(v)
for k, v in envs.items():
fatten_env_namespace([k], v)
return flatten_dict
def get_runtime_envion(key):
def set_runtime_environs(environs):
for k, v in environs.items():
os.environ[k] = v
def get_runtime_environ(key):
return os.getenv(key, None)
def get_trainer():
train_mode = get_runtime_envion("trainer.trainer")
train_mode = get_runtime_environ("trainer.trainer")
return train_mode
......
......@@ -13,6 +13,17 @@
# limitations under the License.
train:
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
strategy: "async"
communicator:
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
epochs: 10
reader:
......
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
strategy: "async"
communicator:
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
......@@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
def set_runtime_envs(cluster_envs, engine_yaml):
if engine_yaml is not None:
def get_engine_extras():
with open(engine_yaml, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_envs = {}
flattens = envs.flatten_environs(_envs)
engine_extras = {}
for k, v in flattens.items():
if k.startswith("train.trainer."):
engine_extras[k] = v
return engine_extras
if cluster_envs is None:
cluster_envs = {}
envs.set_runtime_envions(cluster_envs)
envs.set_runtime_envions(_envs)
envs.set_runtime_environs(cluster_envs)
envs.set_runtime_environs(get_engine_extras())
need_print = {}
for k, v in os.environ.items():
......@@ -62,9 +69,8 @@ def cluster_engine(args):
cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.engine_extras)
set_runtime_envs(cluster_envs, args.model)
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -146,11 +152,6 @@ if __name__ == "__main__":
if args.engine.upper() not in clusters:
raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))
if args.engine_extras is not None:
if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras):
raise ValueError(
"argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))
which_engine = get_engine(args.engine)
engine = which_engine(args)
engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册