提交 7f99ff03 编写于 作者: T tangwei

code clean

上级 33edc01a
...@@ -29,11 +29,7 @@ train: ...@@ -29,11 +29,7 @@ train:
epochs: 10 epochs: 10
trainer: "SingleTraining" trainer: "SingleTraining"
strategy:
mode: "async"
reader: reader:
mode: "dataset"
batch_size: 2 batch_size: 2
pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train"
...@@ -56,8 +52,6 @@ train: ...@@ -56,8 +52,6 @@ train:
inference: inference:
dirname: "models_for_inference" dirname: "models_for_inference"
epoch_interval: 4 epoch_interval: 4
feed_varnames: ["C1", "C2", "C3"]
fetch_varnames: "predict"
save_last: True save_last: True
evaluate: evaluate:
......
...@@ -31,6 +31,6 @@ from fleetrec.trainer.factory import TrainerFactory ...@@ -31,6 +31,6 @@ from fleetrec.trainer.factory import TrainerFactory
if __name__ == "__main__": if __name__ == "__main__":
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml') yaml = os.path.join(abs_dir, 'ctr-dnn_train_single.yaml')
trainer = TrainerFactory.create(yaml) trainer = TrainerFactory.create(yaml)
trainer.run() trainer.run()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -14,13 +14,27 @@ ...@@ -14,13 +14,27 @@
from __future__ import print_function from __future__ import print_function
import abc import abc
import os
import paddle.fluid.incubate.data_generator as dg import paddle.fluid.incubate.data_generator as dg
import yaml
from fleetrec.utils import envs
class Reader(dg.MultiSlotDataGenerator): class Reader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, config):
super().__init__()
if os.path.exists(config) and os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
envs.set_global_envs(_config)
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
pass pass
......
...@@ -16,8 +16,8 @@ import sys ...@@ -16,8 +16,8 @@ import sys
from fleetrec.utils.envs import lazy_instance from fleetrec.utils.envs import lazy_instance
if len(sys.argv) != 3: if len(sys.argv) != 4:
raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE") raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path")
reader_package = sys.argv[1] reader_package = sys.argv[1]
...@@ -26,6 +26,8 @@ if sys.argv[2] == "TRAIN": ...@@ -26,6 +26,8 @@ if sys.argv[2] == "TRAIN":
else: else:
reader_name = "EvaluateReader" reader_name = "EvaluateReader"
yaml_abs_path = sys.argv[3]
reader_class = lazy_instance(reader_package, reader_name) reader_class = lazy_instance(reader_package, reader_name)
reader = reader_class() reader = reader_class(yaml_abs_path)
reader.init()
reader.run_from_stdin() reader.run_from_stdin()
...@@ -35,9 +35,9 @@ class TrainerFactory(object): ...@@ -35,9 +35,9 @@ class TrainerFactory(object):
train_mode = envs.get_global_env("train.trainer") train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining": if train_mode == "SingleTraining":
trainer = SingleTrainer() trainer = SingleTrainer(config)
elif train_mode == "ClusterTraining": elif train_mode == "ClusterTraining":
trainer = ClusterTrainer() trainer = ClusterTrainer(config)
elif train_mode == "CtrTrainer": elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config) trainer = CtrPaddleTrainer(config)
else: else:
......
...@@ -28,6 +28,7 @@ class Trainer(object): ...@@ -28,6 +28,7 @@ class Trainer(object):
self._exe = fluid.Executor(self._place) self._exe = fluid.Executor(self._place)
self._exector_context = {} self._exector_context = {}
self._context = {'status': 'uninit', 'is_exit': False} self._context = {'status': 'uninit', 'is_exit': False}
self._config = config
def regist_context_processor(self, status_name, processor): def regist_context_processor(self, status_name, processor):
""" """
......
...@@ -44,7 +44,7 @@ class TranspileTrainer(Trainer): ...@@ -44,7 +44,7 @@ class TranspileTrainer(Trainer):
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '..', 'reader_implement.py') reader = os.path.join(abs_dir, '..', 'reader_implement.py')
pipe_cmd = "python {} {} {}".format(reader, reader_class, "TRAIN") pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config)
train_data_path = envs.get_global_env("train_data_path", None, namespace) train_data_path = envs.get_global_env("train_data_path", None, namespace)
dataset = fluid.DatasetFactory().create_dataset() dataset = fluid.DatasetFactory().create_dataset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册