提交 42347db7 编写于 作者: T tangwei

fix bug

上级 eeaf9942
......@@ -45,7 +45,7 @@ class TrainerFactory(object):
trainer_abs = trainers.get(train_mode, None)
if trainer_abs is None:
if not os.path.exists(train_mode) or os.path.isfile(train_mode):
if not os.path.exists(train_mode) or not os.path.isfile(train_mode):
raise ValueError("trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode
train_mode = "UserDefineTrainer"
......
......@@ -21,8 +21,18 @@ global_envs = {}
def set_runtime_envions(envs):
assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs):
for k, v in local_envs.items():
if isinstance(v, dict):
nests = copy.deepcopy(namespace_nests)
nests.append(k)
fatten_env_namespace(nests, v)
else:
global_k = ".".join(namespace_nests + [k])
os.environ[global_k] = str(v)
for k, v in envs.items():
os.environ[k] = str(v)
fatten_env_namespace([k], v)
def get_runtime_envion(key):
......
......@@ -3,10 +3,11 @@ trainer:
threads: 4
# for cluster training
strategy: "async"
communicator:
strategy: "async"
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
trainer: "UserDefineTrainer"
location: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
......@@ -21,8 +21,9 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if cluster_envs is None:
cluster_envs = {}
cluster_envs.update(cluster_envs)
cluster_envs.update(_envs)
envs.set_runtime_envions(cluster_envs)
# envs.set_runtime_envions(cluster_envs)
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
......@@ -40,7 +41,10 @@ def get_engine(engine):
def single_engine(args):
print("use single engine to run model: {}".format(args.model))
single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"}
single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2"
set_runtime_envs(single_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -49,7 +53,8 @@ def single_engine(args):
def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "ClusterTrainer"}
cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer"
set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs)
......@@ -60,7 +65,8 @@ def cluster_engine(args):
def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "CtrCodingTrainer"}
cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model)
......@@ -77,7 +83,9 @@ def local_cluster_engine(args):
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy.mode"] = "async"
cluster_envs["trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2"
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册