提交 42347db7 编写于 作者: T tangwei

fix bug

上级 eeaf9942
...@@ -45,7 +45,7 @@ class TrainerFactory(object): ...@@ -45,7 +45,7 @@ class TrainerFactory(object):
trainer_abs = trainers.get(train_mode, None) trainer_abs = trainers.get(train_mode, None)
if trainer_abs is None: if trainer_abs is None:
if not os.path.exists(train_mode) or os.path.isfile(train_mode): if not os.path.exists(train_mode) or not os.path.isfile(train_mode):
raise ValueError("trainer {} can not be recognized".format(train_mode)) raise ValueError("trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode trainer_abs = train_mode
train_mode = "UserDefineTrainer" train_mode = "UserDefineTrainer"
......
...@@ -21,8 +21,18 @@ global_envs = {} ...@@ -21,8 +21,18 @@ global_envs = {}
def set_runtime_envions(envs): def set_runtime_envions(envs):
assert isinstance(envs, dict) assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs):
for k, v in local_envs.items():
if isinstance(v, dict):
nests = copy.deepcopy(namespace_nests)
nests.append(k)
fatten_env_namespace(nests, v)
else:
global_k = ".".join(namespace_nests + [k])
os.environ[global_k] = str(v)
for k, v in envs.items(): for k, v in envs.items():
os.environ[k] = str(v) fatten_env_namespace([k], v)
def get_runtime_envion(key): def get_runtime_envion(key):
......
...@@ -3,10 +3,11 @@ trainer: ...@@ -3,10 +3,11 @@ trainer:
threads: 4 threads: 4
# for cluster training # for cluster training
strategy: "async"
communicator: communicator:
strategy: "async"
send_queue_size: 4 send_queue_size: 4
min_send_grad_num_before_recv: 4 min_send_grad_num_before_recv: 4
thread_pool_size: 5 thread_pool_size: 5
max_merge_var_num: 4 max_merge_var_num: 4
trainer: "UserDefineTrainer"
location: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
...@@ -21,8 +21,9 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -21,8 +21,9 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if cluster_envs is None: if cluster_envs is None:
cluster_envs = {} cluster_envs = {}
cluster_envs.update(cluster_envs)
cluster_envs.update(_envs) cluster_envs.update(_envs)
envs.set_runtime_envions(cluster_envs) # envs.set_runtime_envions(cluster_envs)
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
...@@ -40,7 +41,10 @@ def get_engine(engine): ...@@ -40,7 +41,10 @@ def get_engine(engine):
def single_engine(args): def single_engine(args):
print("use single engine to run model: {}".format(args.model)) print("use single engine to run model: {}".format(args.model))
single_envs = {"trainer.trainer": "SingleTrainer", "trainer.threads": "2"}
single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2"
set_runtime_envs(single_envs, args.engine_extras) set_runtime_envs(single_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -49,7 +53,8 @@ def single_engine(args): ...@@ -49,7 +53,8 @@ def single_engine(args):
def cluster_engine(args): def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "ClusterTrainer"} cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
...@@ -60,7 +65,8 @@ def cluster_engine(args): ...@@ -60,7 +65,8 @@ def cluster_engine(args):
def cluster_mpi_engine(args): def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "CtrCodingTrainer"} cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
...@@ -77,7 +83,9 @@ def local_cluster_engine(args): ...@@ -77,7 +83,9 @@ def local_cluster_engine(args):
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy.mode"] = "async" cluster_envs["trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2"
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册