提交 2e91f58f 编写于 作者: T tangwei

fix import

上级 3eba3369
......@@ -25,7 +25,7 @@ class TrainerFactory(object):
pass
@staticmethod
def _build_trainer(config, yaml_path):
def _build_trainer(yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer")
......@@ -40,8 +40,8 @@ class TrainerFactory(object):
from fleetrec.core.trainers.cluster_trainer import ClusterTrainer
trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTraining":
from fleetrec.core.trainers.ctr_modul_trainer import CtrPaddleTrainer
trainer = CtrPaddleTrainer(config)
from fleetrec.core.trainers.ctr_coding_trainer import CtrPaddleTrainer
trainer = CtrPaddleTrainer(yaml_path)
elif train_mode == "UserDefineTraining":
train_location = envs.get_global_env("train.location")
train_dirname = os.path.dirname(train_location)
......@@ -63,7 +63,7 @@ class TrainerFactory(object):
raise ValueError("fleetrec's config only support yaml")
envs.set_global_envs(_config)
trainer = TrainerFactory._build_trainer(_config, config)
trainer = TrainerFactory._build_trainer(config)
return trainer
......
......@@ -14,6 +14,8 @@
import abc
import time
import yaml
from paddle import fluid
......@@ -28,7 +30,10 @@ class Trainer(object):
self._exe = fluid.Executor(self._place)
self._exector_context = {}
self._context = {'status': 'uninit', 'is_exit': False}
self._config = config
self._config_yaml = config
with open(config, 'r') as rb:
self._config = yaml.load(rb.read(), Loader=yaml.FullLoader)
def regist_context_processor(self, status_name, processor):
"""
......
......@@ -62,7 +62,7 @@ class CtrPaddleTrainer(Trainer):
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config)
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
dataset = fluid.DatasetFactory().create_dataset()
......
......@@ -17,7 +17,6 @@ def run(model_yaml):
def single_engine(single_envs, model_yaml):
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
run(model_yaml)
......@@ -33,8 +32,8 @@ def local_cluster_engine(cluster_envs, model_yaml):
def local_mpi_engine(cluster_envs, model_yaml):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
launch = LocalMPIEngine(cluster_envs, model_yaml)
launch.run()
......@@ -79,7 +78,7 @@ if __name__ == "__main__":
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
......@@ -100,7 +99,7 @@ if __name__ == "__main__":
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册