提交 8b9b9cf9 编写于 作者: T tangwei

merge yaml two to one

上级 59f4df30
...@@ -89,7 +89,7 @@ def user_define_engine(engine_yaml): ...@@ -89,7 +89,7 @@ def user_define_engine(engine_yaml):
_config = yaml.load(rb.read(), Loader=yaml.FullLoader) _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None assert _config is not None
envs.set_runtime_envions(_config) envs.set_runtime_environs(_config)
train_location = envs.get_global_env("engine.file") train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location) train_dirname = os.path.dirname(train_location)
......
...@@ -18,13 +18,14 @@ import copy ...@@ -18,13 +18,14 @@ import copy
global_envs = {} global_envs = {}
def set_runtime_envions(envs): def flatten_environs(envs):
flatten_dict = {}
assert isinstance(envs, dict) assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs): def fatten_env_namespace(namespace_nests, local_envs):
if not isinstance(local_envs, dict): if not isinstance(local_envs, dict):
global_k = ".".join(namespace_nests) global_k = ".".join(namespace_nests)
os.environ[global_k] = str(local_envs) flatten_dict[global_k] = str(local_envs)
else: else:
for k, v in local_envs.items(): for k, v in local_envs.items():
if isinstance(v, dict): if isinstance(v, dict):
...@@ -33,18 +34,24 @@ def set_runtime_envions(envs): ...@@ -33,18 +34,24 @@ def set_runtime_envions(envs):
fatten_env_namespace(nests, v) fatten_env_namespace(nests, v)
else: else:
global_k = ".".join(namespace_nests + [k]) global_k = ".".join(namespace_nests + [k])
os.environ[global_k] = str(v) flatten_dict[global_k] = str(v)
for k, v in envs.items(): for k, v in envs.items():
fatten_env_namespace([k], v) fatten_env_namespace([k], v)
return flatten_dict
def get_runtime_envion(key):
def set_runtime_environs(environs):
for k, v in environs.items():
os.environ[k] = v
def get_runtime_environ(key):
return os.getenv(key, None) return os.getenv(key, None)
def get_trainer(): def get_trainer():
train_mode = get_runtime_envion("trainer.trainer") train_mode = get_runtime_environ("trainer.trainer")
return train_mode return train_mode
......
...@@ -13,6 +13,17 @@ ...@@ -13,6 +13,17 @@
# limitations under the License. # limitations under the License.
train: train:
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
strategy: "async"
communicator:
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
epochs: 10 epochs: 10
reader: reader:
......
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
strategy: "async"
communicator:
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
...@@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"] ...@@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
def set_runtime_envs(cluster_envs, engine_yaml): def set_runtime_envs(cluster_envs, engine_yaml):
if engine_yaml is not None: def get_engine_extras():
with open(engine_yaml, 'r') as rb: with open(engine_yaml, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader) _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_envs = {} flattens = envs.flatten_environs(_envs)
engine_extras = {}
for k, v in flattens.items():
if k.startswith("train.trainer."):
engine_extras[k] = v
return engine_extras
if cluster_envs is None: if cluster_envs is None:
cluster_envs = {} cluster_envs = {}
envs.set_runtime_envions(cluster_envs)
envs.set_runtime_envions(_envs) envs.set_runtime_environs(cluster_envs)
envs.set_runtime_environs(get_engine_extras())
need_print = {} need_print = {}
for k, v in os.environ.items(): for k, v in os.environ.items():
...@@ -62,9 +69,8 @@ def cluster_engine(args): ...@@ -62,9 +69,8 @@ def cluster_engine(args):
cluster_envs = {} cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster" cluster_envs["trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.model)
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -146,11 +152,6 @@ if __name__ == "__main__": ...@@ -146,11 +152,6 @@ if __name__ == "__main__":
if args.engine.upper() not in clusters: if args.engine.upper() not in clusters:
raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters)) raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))
if args.engine_extras is not None:
if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras):
raise ValueError(
"argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))
which_engine = get_engine(args.engine) which_engine = get_engine(args.engine)
engine = which_engine(args) engine = which_engine(args)
engine.run() engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册