提交 7f99ff03 编写于 作者: T tangwei

code clean

上级 33edc01a
......@@ -29,11 +29,7 @@ train:
epochs: 10
trainer: "SingleTraining"
strategy:
mode: "async"
reader:
mode: "dataset"
batch_size: 2
pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train"
......@@ -56,8 +52,6 @@ train:
inference:
dirname: "models_for_inference"
epoch_interval: 4
feed_varnames: ["C1", "C2", "C3"]
fetch_varnames: "predict"
save_last: True
evaluate:
......
......@@ -31,6 +31,6 @@ from fleetrec.trainer.factory import TrainerFactory
if __name__ == "__main__":
abs_dir = os.path.dirname(os.path.abspath(__file__))
yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml')
yaml = os.path.join(abs_dir, 'ctr-dnn_train_single.yaml')
trainer = TrainerFactory.create(yaml)
trainer.run()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -14,13 +14,27 @@
from __future__ import print_function
import abc
import os
import paddle.fluid.incubate.data_generator as dg
import yaml
from fleetrec.utils import envs
class Reader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta
def __init__(self, config):
super().__init__()
if os.path.exists(config) and os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("reader config only support yaml")
envs.set_global_envs(_config)
@abc.abstractmethod
def init(self):
pass
......
......@@ -16,8 +16,8 @@ import sys
from fleetrec.utils.envs import lazy_instance
if len(sys.argv) != 3:
raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE")
if len(sys.argv) != 4:
raise ValueError("reader only accept 3 argument: 1. reader_class 2.train/evaluate 3.yaml_abs_path")
reader_package = sys.argv[1]
......@@ -26,6 +26,8 @@ if sys.argv[2] == "TRAIN":
else:
reader_name = "EvaluateReader"
yaml_abs_path = sys.argv[3]
reader_class = lazy_instance(reader_package, reader_name)
reader = reader_class()
reader = reader_class(yaml_abs_path)
reader.init()
reader.run_from_stdin()
......@@ -35,9 +35,9 @@ class TrainerFactory(object):
train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining":
trainer = SingleTrainer()
trainer = SingleTrainer(config)
elif train_mode == "ClusterTraining":
trainer = ClusterTrainer()
trainer = ClusterTrainer(config)
elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config)
else:
......
......@@ -28,6 +28,7 @@ class Trainer(object):
self._exe = fluid.Executor(self._place)
self._exector_context = {}
self._context = {'status': 'uninit', 'is_exit': False}
self._config = config
def regist_context_processor(self, status_name, processor):
"""
......
......@@ -44,7 +44,7 @@ class TranspileTrainer(Trainer):
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '..', 'reader_implement.py')
pipe_cmd = "python {} {} {}".format(reader, reader_class, "TRAIN")
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
dataset = fluid.DatasetFactory().create_dataset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册