提交 02ec66fd 编写于 作者: T tangwei

update setup.py

上级 7c852a31
......@@ -33,15 +33,15 @@ class TrainerFactory(object):
def _build_trainer(config, yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer")
train_mode = envs.get_runtime_envion("train.trainer")
if train_mode == "SingleTraining":
trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer":
elif train_mode == "CtrTraining":
trainer = CtrPaddleTrainer(config)
elif train_mode == "UserDefineTrainer":
elif train_mode == "UserDefineTraining":
train_location = envs.get_global_env("train.location")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
......
......@@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
global_envs = {}
def set_runtime_envions(envs):
assert isinstance(envs, dict)
for k, v in envs.items():
os.environ[k] = v
def get_runtime_envion(key):
return os.getenv(key, None)
def set_global_envs(envs):
assert isinstance(envs, dict)
......@@ -87,4 +98,3 @@ def lazy_instance(package, class_name):
model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
......@@ -11,11 +11,10 @@ def run(model_yaml):
trainer.run()
def single_engine(model_yaml):
single_envs = {}
single_envs["singleTraning"] = True
def single_engine(single_envs, model_yaml):
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
run(model_yaml)
......@@ -47,25 +46,30 @@ if __name__ == "__main__":
if args.engine == "Single":
print("use SingleTraining to run model: {}".format(args.model))
single_engine(args.model)
single_envs = {}
single_envs["train.trainer"] = "SingleTraining"
single_engine(single_envs, args.model)
elif args.engine == "LocalCluster":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.server_num"] = 1
cluster_envs["train.worker_num"] = 1
cluster_envs["train.start_port"] = 36001
cluster_envs["train.log_dir"] = "logs"
cluster_envs["train.trainer"] = "SingleTraining"
local_cluster_engine(cluster_envs, args.model)
elif args.engine == "LocalMPI":
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.server_num"] = 1
cluster_envs["train.worker_num"] = 1
cluster_envs["train.start_port"] = 36001
cluster_envs["train.log_dir"] = "logs"
cluster_envs["train.trainer"] = "CtrTraining"
local_mpi_engine(cluster_envs, args.model)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册