提交 2c77c937 编写于 作者: T tangwei

code clean

上级 7f99ff03
...@@ -31,8 +31,8 @@ train: ...@@ -31,8 +31,8 @@ train:
reader: reader:
batch_size: 2 batch_size: 2
pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" class: "fleetrec.models.ctr_dnn.data_generator"
train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" train_data_path: "/root/FleetRec/fleetrec/models/ctr_dnn/data/train/"
model: model:
models: "fleetrec.models.ctr_dnn.model" models: "fleetrec.models.ctr_dnn.model"
......
...@@ -26,7 +26,8 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -26,7 +26,8 @@ class Reader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, config): def __init__(self, config):
super().__init__() dg.MultiSlotDataGenerator.__init__(self)
if os.path.exists(config) and os.path.isfile(config): if os.path.exists(config) and os.path.isfile(config):
with open(config, 'r') as rb: with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader) _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
......
...@@ -30,14 +30,14 @@ class TrainerFactory(object): ...@@ -30,14 +30,14 @@ class TrainerFactory(object):
pass pass
@staticmethod @staticmethod
def _build_trainer(config): def _build_trainer(config, yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs())) print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer") train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining": if train_mode == "SingleTraining":
trainer = SingleTrainer(config) trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining": elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(config) trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer": elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config) trainer = CtrPaddleTrainer(config)
else: else:
...@@ -75,7 +75,7 @@ class TrainerFactory(object): ...@@ -75,7 +75,7 @@ class TrainerFactory(object):
if mode == "ClusterTraining" and container == "local" and not instance: if mode == "ClusterTraining" and container == "local" and not instance:
trainer = TrainerFactory._build_engine(config) trainer = TrainerFactory._build_engine(config)
else: else:
trainer = TrainerFactory._build_trainer(_config) trainer = TrainerFactory._build_trainer(_config, config)
return trainer return trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册