From 3eba33699419894872104ed564cb1b7165898718 Mon Sep 17 00:00:00 2001
From: tangwei
Date: Thu, 16 Apr 2020 14:15:18 +0800
Subject: [PATCH] fix import

---
 fleetrec/core/engine/local_mpi_engine.py     |   6 +-
 fleetrec/core/trainers/ctr_coding_trainer.py | 139 +++++++++++++++++++
 2 files changed, 142 insertions(+), 3 deletions(-)
 create mode 100755 fleetrec/core/trainers/ctr_coding_trainer.py

diff --git a/fleetrec/core/engine/local_mpi_engine.py b/fleetrec/core/engine/local_mpi_engine.py
index c5c916ac..84ed6c30 100644
--- a/fleetrec/core/engine/local_mpi_engine.py
+++ b/fleetrec/core/engine/local_mpi_engine.py
@@ -34,8 +34,8 @@ class LocalMPIEngine(Engine):
         log_fns = []
         factory = "fleetrec.core.factory"
-        mpi_cmd = "mpirun -npernode 2 -timestamp-output -tag-output".split(" ")
-        cmd = mpi_cmd.extend([sys.executable, "-u", "-m", factory, self.trainer])
+        cmd = "mpirun -npernode 2 -timestamp-output -tag-output".split(" ")
+        cmd.extend([sys.executable, "-u", "-m", factory, self.trainer])
 
         if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
@@ -49,7 +49,7 @@ class LocalMPIEngine(Engine):
             for i in range(len(procs)):
                 if len(log_fns) > 0:
                     log_fns[i].close()
-                procs[i].terminate()
+                procs[i].wait()
         print("all workers and parameter servers already completed", file=sys.stderr)
 
     def run(self):
diff --git a/fleetrec/core/trainers/ctr_coding_trainer.py b/fleetrec/core/trainers/ctr_coding_trainer.py
new file mode 100755
index 00000000..d8751c30
--- /dev/null
+++ b/fleetrec/core/trainers/ctr_coding_trainer.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import time
+import json
+import datetime
+import numpy as np
+
+import paddle.fluid as fluid
+from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
+
+from fleetrec.core.utils import envs
+from fleetrec.core.trainer import Trainer
+
+
+class CtrPaddleTrainer(Trainer):
+    """R
+    """
+
+    def __init__(self, config):
+        """R
+        """
+        Trainer.__init__(self, config)
+
+        self.global_config = config
+        self._metrics = {}
+        self.processor_register()
+
+    def processor_register(self):
+        role = MPISymetricRoleMaker()
+        fleet.init(role)
+
+        if fleet.is_server():
+            self.regist_context_processor('uninit', self.instance)
+            self.regist_context_processor('init_pass', self.init)
+            self.regist_context_processor('server_pass', self.server)
+        else:
+            self.regist_context_processor('uninit', self.instance)
+            self.regist_context_processor('init_pass', self.init)
+            self.regist_context_processor('train_pass', self.train)
+            self.regist_context_processor('terminal_pass', self.terminal)
+
+    def _get_dataset(self):
+        namespace = "train.reader"
+
+        inputs = self.model.get_inputs()
+        threads = envs.get_global_env("train.threads", None)
+        batch_size = envs.get_global_env("batch_size", None, namespace)
+        reader_class = envs.get_global_env("class", None, namespace)
+        abs_dir = os.path.dirname(os.path.abspath(__file__))
+        reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
+        pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config)
+        train_data_path = envs.get_global_env("train_data_path", None, namespace)
+
+        dataset = fluid.DatasetFactory().create_dataset()
+        dataset.set_use_var(inputs)
+        dataset.set_pipe_command(pipe_cmd)
+        dataset.set_batch_size(batch_size)
+        dataset.set_thread(threads)
+        file_list = [
+            os.path.join(train_data_path, x)
+            for x in os.listdir(train_data_path)
+        ]
+
+        dataset.set_filelist(file_list)
+        return dataset
+
+    def instance(self, context):
+        models = envs.get_global_env("train.model.models")
+        model_class = envs.lazy_instance(models, "Model")
+        self.model = model_class(None)
+        context['status'] = 'init_pass'
+
+    def init(self, context):
+        """R
+        """
+        self.model.train_net()
+        optimizer = self.model.optimizer()
+
+        optimizer = fleet.distributed_optimizer(optimizer, strategy={"use_cvm": False})
+        optimizer.minimize(self.model.get_cost_op())
+
+        if fleet.is_server():
+            context['status'] = 'server_pass'
+        else:
+            self.fetch_vars = []
+            self.fetch_alias = []
+            self.fetch_period = self.model.get_fetch_period()
+
+            metrics = self.model.get_metrics()
+            if metrics:
+                self.fetch_vars = metrics.values()
+                self.fetch_alias = metrics.keys()
+            context['status'] = 'train_pass'
+
+    def server(self, context):
+        fleet.run_server()
+        context['is_exit'] = True
+
+    def train(self, context):
+        self._exe.run(fluid.default_startup_program())
+        fleet.init_worker()
+
+        dataset = self._get_dataset()
+
+        shuf = np.array([fleet.worker_index()])
+        gs = shuf * 0
+        fleet._role_maker._node_type_comm.Allreduce(shuf, gs)
+
+        print("trainer id: {}, trainers: {}, gs: {}".format(fleet.worker_index(), fleet.worker_num(), gs))
+
+        epochs = envs.get_global_env("train.epochs")
+
+        for i in range(epochs):
+            self._exe.train_from_dataset(program=fluid.default_main_program(),
+                                         dataset=dataset,
+                                         fetch_list=self.fetch_vars,
+                                         fetch_info=self.fetch_alias,
+                                         print_period=self.fetch_period)
+
+        context['status'] = 'terminal_pass'
+        fleet.stop_worker()
+
+    def terminal(self, context):
print("terminal ended.") + context['is_exit'] = True -- GitLab