diff --git a/fleetrec/examples/ctr-dnn_train_single.yaml b/fleetrec/examples/ctr-dnn_train_single.yaml index 985adebe9ea8872dc556ccaf48775f75716b318d..c0d6c6a4010b7c95df8ffad246bd462659cc7c79 100644 --- a/fleetrec/examples/ctr-dnn_train_single.yaml +++ b/fleetrec/examples/ctr-dnn_train_single.yaml @@ -29,12 +29,8 @@ train: epochs: 10 trainer: "SingleTraining" - strategy: - mode: "async" - reader: - mode: "dataset" - batch_size: 2 + batch_size: 2 pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" @@ -56,8 +52,6 @@ train: inference: dirname: "models_for_inference" epoch_interval: 4 - feed_varnames: ["C1", "C2", "C3"] - fetch_varnames: "predict" save_last: True evaluate: diff --git a/fleetrec/examples/train.py b/fleetrec/examples/train.py index fb2fd54b2df7cefb1c65679b46092fca986e7c69..835ab193ca6568e9249f9bfa39d3f29d81ac509d 100644 --- a/fleetrec/examples/train.py +++ b/fleetrec/examples/train.py @@ -31,6 +31,6 @@ from fleetrec.trainer.factory import TrainerFactory if __name__ == "__main__": abs_dir = os.path.dirname(os.path.abspath(__file__)) - yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml') + yaml = os.path.join(abs_dir, 'ctr-dnn_train_single.yaml') trainer = TrainerFactory.create(yaml) trainer.run() diff --git a/fleetrec/reader/data_loader.py b/fleetrec/reader/data_loader.py deleted file mode 100644 index abf198b97e6e818e1fbe59006f98492640bcee54..0000000000000000000000000000000000000000 --- a/fleetrec/reader/data_loader.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/fleetrec/reader/reader.py b/fleetrec/reader/reader.py index 8e696747f99f63ae2f3323ea851323f1c61d50f1..6f6bcfd3ceac389eaf99659b57dfea52d35e497d 100644 --- a/fleetrec/reader/reader.py +++ b/fleetrec/reader/reader.py @@ -14,13 +14,27 @@ from __future__ import print_function import abc +import os import paddle.fluid.incubate.data_generator as dg +import yaml + +from fleetrec.utils import envs class Reader(dg.MultiSlotDataGenerator): __metaclass__ = abc.ABCMeta + def __init__(self, config): + super().__init__() + if os.path.exists(config) and os.path.isfile(config): + with open(config, 'r') as rb: + _config = yaml.load(rb.read(), Loader=yaml.FullLoader) + else: + raise ValueError("reader config only support yaml") + + envs.set_global_envs(_config) + @abc.abstractmethod def init(self): pass diff --git a/fleetrec/reader/reader_instance.py b/fleetrec/reader/reader_instance.py index 2b771794c23d0eaad22f436358ba72027f250e1a..b4c2d9c15b07e079fafa711f32795de63408651e 100644 --- a/fleetrec/reader/reader_instance.py +++ b/fleetrec/reader/reader_instance.py @@ -16,8 +16,8 @@ import sys from fleetrec.utils.envs import lazy_instance -if len(sys.argv) != 3: - raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE") +if len(sys.argv) != 4: + raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path") reader_package = sys.argv[1] @@ -26,6 +26,8 @@ if sys.argv[2] == "TRAIN": else: reader_name = "EvaluateReader" +yaml_abs_path = sys.argv[3] reader_class = lazy_instance(reader_package, reader_name) -reader = reader_class() +reader = reader_class(yaml_abs_path) +reader.init() reader.run_from_stdin() diff --git a/fleetrec/trainer/factory.py b/fleetrec/trainer/factory.py index 74369eec9fc15dd35e67837ec39bfa43267a49be..c3af2f0d7877a4111231205b5b8e9e2911ff4a5b 100644 --- a/fleetrec/trainer/factory.py +++ b/fleetrec/trainer/factory.py @@ -35,9 +35,9 @@ class TrainerFactory(object): train_mode = envs.get_global_env("train.trainer") if train_mode == "SingleTraining": - trainer = SingleTrainer() + trainer = SingleTrainer(config) elif train_mode == "ClusterTraining": - trainer = ClusterTrainer() + trainer = ClusterTrainer(config) elif train_mode == "CtrTrainer": trainer = CtrPaddleTrainer(config) else: diff --git a/fleetrec/trainer/trainer.py b/fleetrec/trainer/trainer.py index b4b3dc57f63e49c494daf66ef8c2c8678e1838ad..0158202b420b40783d7b9bb6f70dce794a9680bd 100755 --- a/fleetrec/trainer/trainer.py +++ b/fleetrec/trainer/trainer.py @@ -28,6 +28,7 @@ class Trainer(object): self._exe = fluid.Executor(self._place) self._exector_context = {} self._context = {'status': 'uninit', 'is_exit': False} + self._config = config def regist_context_processor(self, status_name, processor): """ diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 976ce892797478c19074a207a330cfa4aba3ded6..ebdfdefc6b00b7f74d36fcbbd521f5cafd3a5312 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '..', 'reader_implement.py') - pipe_cmd = "python {} {} {}".format(reader, reader_class, "TRAIN") + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) train_data_path = envs.get_global_env("train_data_path", None, namespace) dataset = fluid.DatasetFactory().create_dataset()