From 2c77c937372225ffcc069729ba6683c10a934223 Mon Sep 17 00:00:00 2001 From: tangwei Date: Tue, 14 Apr 2020 13:17:33 +0800 Subject: [PATCH] code clean --- fleetrec/examples/ctr-dnn_train_single.yaml | 4 ++-- fleetrec/reader/reader.py | 3 ++- fleetrec/trainer/factory.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fleetrec/examples/ctr-dnn_train_single.yaml b/fleetrec/examples/ctr-dnn_train_single.yaml index c0d6c6a4..91dca18f 100644 --- a/fleetrec/examples/ctr-dnn_train_single.yaml +++ b/fleetrec/examples/ctr-dnn_train_single.yaml @@ -31,8 +31,8 @@ train: reader: batch_size: 2 - pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" - train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" + class: "fleetrec.models.ctr_dnn.data_generator" + train_data_path: "/root/FleetRec/fleetrec/models/ctr_dnn/data/train/" model: models: "fleetrec.models.ctr_dnn.model" diff --git a/fleetrec/reader/reader.py b/fleetrec/reader/reader.py index 6f6bcfd3..cf217e0d 100644 --- a/fleetrec/reader/reader.py +++ b/fleetrec/reader/reader.py @@ -26,7 +26,8 @@ class Reader(dg.MultiSlotDataGenerator): __metaclass__ = abc.ABCMeta def __init__(self, config): - super().__init__() + dg.MultiSlotDataGenerator.__init__(self) + if os.path.exists(config) and os.path.isfile(config): with open(config, 'r') as rb: _config = yaml.load(rb.read(), Loader=yaml.FullLoader) diff --git a/fleetrec/trainer/factory.py b/fleetrec/trainer/factory.py index c3af2f0d..c225cb27 100644 --- a/fleetrec/trainer/factory.py +++ b/fleetrec/trainer/factory.py @@ -30,14 +30,14 @@ class TrainerFactory(object): pass @staticmethod - def _build_trainer(config): + def _build_trainer(config, yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) train_mode = envs.get_global_env("train.trainer") if train_mode == "SingleTraining": - trainer = SingleTrainer(config) + trainer = SingleTrainer(yaml_path) elif train_mode == "ClusterTraining": - trainer = ClusterTrainer(config) + trainer = ClusterTrainer(yaml_path) elif train_mode == "CtrTrainer": trainer = CtrPaddleTrainer(config) else: @@ -75,7 +75,7 @@ class TrainerFactory(object): if mode == "ClusterTraining" and container == "local" and not instance: trainer = TrainerFactory._build_engine(config) else: - trainer = TrainerFactory._build_trainer(_config) + trainer = TrainerFactory._build_trainer(_config, config) return trainer -- GitLab