diff --git a/fleetrec/examples/ctr-dnn_train_single.yaml b/fleetrec/examples/ctr-dnn_train_single.yaml index c0d6c6a4010b7c95df8ffad246bd462659cc7c79..91dca18f3dca5c029913aa2e4a35bb58962438a0 100644 --- a/fleetrec/examples/ctr-dnn_train_single.yaml +++ b/fleetrec/examples/ctr-dnn_train_single.yaml @@ -31,8 +31,8 @@ train: reader: batch_size: 2 - pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" - train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" + class: "fleetrec.models.ctr_dnn.data_generator" + train_data_path: "/root/FleetRec/fleetrec/models/ctr_dnn/data/train/" model: models: "fleetrec.models.ctr_dnn.model" diff --git a/fleetrec/reader/reader.py b/fleetrec/reader/reader.py index 6f6bcfd3ceac389eaf99659b57dfea52d35e497d..cf217e0dc91fe6d43db3bf001c43306d55e31537 100644 --- a/fleetrec/reader/reader.py +++ b/fleetrec/reader/reader.py @@ -26,7 +26,8 @@ class Reader(dg.MultiSlotDataGenerator): __metaclass__ = abc.ABCMeta def __init__(self, config): - super().__init__() + dg.MultiSlotDataGenerator.__init__(self) + if os.path.exists(config) and os.path.isfile(config): with open(config, 'r') as rb: _config = yaml.load(rb.read(), Loader=yaml.FullLoader) diff --git a/fleetrec/trainer/factory.py b/fleetrec/trainer/factory.py index c3af2f0d7877a4111231205b5b8e9e2911ff4a5b..c225cb272ae848282002b5a234278d903f1d1d97 100644 --- a/fleetrec/trainer/factory.py +++ b/fleetrec/trainer/factory.py @@ -30,14 +30,14 @@ class TrainerFactory(object): pass @staticmethod - def _build_trainer(config): + def _build_trainer(config, yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) train_mode = envs.get_global_env("train.trainer") if train_mode == "SingleTraining": - trainer = SingleTrainer(config) + trainer = SingleTrainer(yaml_path) elif train_mode == "ClusterTraining": - trainer = ClusterTrainer(config) + trainer = ClusterTrainer(yaml_path) elif train_mode == "CtrTrainer": trainer = CtrPaddleTrainer(config) else: @@ -75,7 +75,7 @@ class TrainerFactory(object): if mode == "ClusterTraining" and container == "local" and not instance: trainer = TrainerFactory._build_engine(config) else: - trainer = TrainerFactory._build_trainer(_config) + trainer = TrainerFactory._build_trainer(_config, config) return trainer