From 7f99ff033dcc7c22c030df852a19cbdecf392637 Mon Sep 17 00:00:00 2001 From: tangwei Date: Tue, 14 Apr 2020 13:01:41 +0800 Subject: [PATCH] code clean --- fleetrec/examples/ctr-dnn_train_single.yaml | 8 +------- fleetrec/examples/train.py | 2 +- fleetrec/reader/data_loader.py | 13 ------------- fleetrec/reader/reader.py | 14 ++++++++++++++ fleetrec/reader/reader_instance.py | 8 +++++--- fleetrec/trainer/factory.py | 4 ++-- fleetrec/trainer/trainer.py | 1 + fleetrec/trainer/transpiler_trainer.py | 2 +- 8 files changed, 25 insertions(+), 27 deletions(-) delete mode 100644 fleetrec/reader/data_loader.py diff --git a/fleetrec/examples/ctr-dnn_train_single.yaml b/fleetrec/examples/ctr-dnn_train_single.yaml index 985adebe..c0d6c6a4 100644 --- a/fleetrec/examples/ctr-dnn_train_single.yaml +++ b/fleetrec/examples/ctr-dnn_train_single.yaml @@ -29,12 +29,8 @@ train: epochs: 10 trainer: "SingleTraining" - strategy: - mode: "async" - reader: - mode: "dataset" - batch_size: 2 + batch_size: 2 pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" @@ -56,8 +52,6 @@ train: inference: dirname: "models_for_inference" epoch_interval: 4 - feed_varnames: ["C1", "C2", "C3"] - fetch_varnames: "predict" save_last: True evaluate: diff --git a/fleetrec/examples/train.py b/fleetrec/examples/train.py index fb2fd54b..835ab193 100644 --- a/fleetrec/examples/train.py +++ b/fleetrec/examples/train.py @@ -31,6 +31,6 @@ from fleetrec.trainer.factory import TrainerFactory if __name__ == "__main__": abs_dir = os.path.dirname(os.path.abspath(__file__)) - yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml') + yaml = os.path.join(abs_dir, 'ctr-dnn_train_single.yaml') trainer = TrainerFactory.create(yaml) trainer.run() diff --git a/fleetrec/reader/data_loader.py b/fleetrec/reader/data_loader.py deleted file mode 100644 index abf198b9..00000000 --- a/fleetrec/reader/data_loader.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/fleetrec/reader/reader.py b/fleetrec/reader/reader.py index 8e696747..6f6bcfd3 100644 --- a/fleetrec/reader/reader.py +++ b/fleetrec/reader/reader.py @@ -14,13 +14,27 @@ from __future__ import print_function import abc +import os import paddle.fluid.incubate.data_generator as dg +import yaml + +from fleetrec.utils import envs class Reader(dg.MultiSlotDataGenerator): __metaclass__ = abc.ABCMeta + def __init__(self, config): + super().__init__() + if os.path.exists(config) and os.path.isfile(config): + with open(config, 'r') as rb: + _config = yaml.load(rb.read(), Loader=yaml.FullLoader) + else: + raise ValueError("reader config only support yaml") + + envs.set_global_envs(_config) + @abc.abstractmethod def init(self): pass diff --git a/fleetrec/reader/reader_instance.py b/fleetrec/reader/reader_instance.py index 2b771794..b4c2d9c1 100644 --- a/fleetrec/reader/reader_instance.py +++ b/fleetrec/reader/reader_instance.py @@ -16,8 +16,8 @@ import sys from fleetrec.utils.envs import lazy_instance -if len(sys.argv) != 3: - raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE") +if len(sys.argv) != 4: + raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path") reader_package = sys.argv[1] @@ -26,6 +26,8 @@ if sys.argv[2] == "TRAIN": else: reader_name = "EvaluateReader" +yaml_abs_path = sys.argv[3] reader_class = lazy_instance(reader_package, reader_name) -reader = reader_class() +reader = reader_class(yaml_abs_path) +reader.init() reader.run_from_stdin() diff --git a/fleetrec/trainer/factory.py b/fleetrec/trainer/factory.py index 74369eec..c3af2f0d 100644 --- a/fleetrec/trainer/factory.py +++ b/fleetrec/trainer/factory.py @@ -35,9 +35,9 @@ class TrainerFactory(object): train_mode = envs.get_global_env("train.trainer") if train_mode == "SingleTraining": - trainer = SingleTrainer() + trainer = SingleTrainer(config) elif train_mode == "ClusterTraining": - trainer = ClusterTrainer() + trainer = ClusterTrainer(config) elif train_mode == "CtrTrainer": trainer = CtrPaddleTrainer(config) else: diff --git a/fleetrec/trainer/trainer.py b/fleetrec/trainer/trainer.py index b4b3dc57..0158202b 100755 --- a/fleetrec/trainer/trainer.py +++ b/fleetrec/trainer/trainer.py @@ -28,6 +28,7 @@ class Trainer(object): self._exe = fluid.Executor(self._place) self._exector_context = {} self._context = {'status': 'uninit', 'is_exit': False} + self._config = config def regist_context_processor(self, status_name, processor): """ diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 976ce892..ebdfdefc 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '..', 'reader_implement.py') - pipe_cmd = "python {} {} {}".format(reader, reader_class, "TRAIN") + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) train_data_path = envs.get_global_env("train_data_path", None, namespace) dataset = fluid.DatasetFactory().create_dataset() -- GitLab