# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Online-learning trainer: fleet (parameter-server) training with fluid that
walks the training data one hourly partition at a time.
"""

from __future__ import print_function

import os
import time
import datetime

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker

from paddlerec.core.utils import envs
from paddlerec.core.trainers.transpiler_trainer import TranspileTrainer


class ClusterTrainer(TranspileTrainer):
    """Trainer that runs as a fleet server or worker and trains hour by hour.

    Each method registered in ``processor_register`` handles one named pass
    and advances the run by writing the next pass name to
    ``context['status']`` (or ``context['is_exit']`` to finish).
    """

    def processor_register(self):
        """Initialise fleet and register the pass handlers for this role."""
        role = PaddleCloudRoleMaker()
        fleet.init(role)

        # Both roles share the first two passes.
        self.regist_context_processor('uninit', self.instance)
        self.regist_context_processor('init_pass', self.init)

        if fleet.is_server():
            self.regist_context_processor('server_pass', self.server)
            return

        self.regist_context_processor('startup_pass', self.startup)

        dataset_class = envs.get_global_env(
            "dataset_class", None, "train.reader")
        if envs.get_platform() == "LINUX" and dataset_class != "DataLoader":
            self.regist_context_processor('train_pass', self.dataset_train)
        else:
            self.regist_context_processor(
                'train_pass', self.dataloader_train)

        self.regist_context_processor('infer_pass', self.infer)
        self.regist_context_processor('terminal_pass', self.terminal)

    def build_strategy(self):
        """Create the distributed strategy named by ``train.trainer.strategy``.

        Returns:
            The strategy object built by ``StrategyFactory``; it is also
            stored on ``self.strategy``.

        Raises:
            AssertionError: if the configured mode is not one of
                ``async``, ``geo``, ``sync`` or ``half_async``.
        """
        mode = envs.get_runtime_environ("train.trainer.strategy")
        supported = ["async", "geo", "sync", "half_async"]
        assert mode in supported, \
            "unsupported strategy {!r}, expected one of {}".format(
                mode, supported)

        if mode == "geo":
            push_num = envs.get_global_env("train.strategy.mode.push_num", 100)
            strategy = StrategyFactory.create_geo_strategy(push_num)
        else:
            factories = {
                "async": StrategyFactory.create_async_strategy,
                "sync": StrategyFactory.create_sync_strategy,
                "half_async": StrategyFactory.create_half_async_strategy,
            }
            strategy = factories[mode]()

        assert strategy is not None, \
            "strategy factory returned None for mode {!r}".format(mode)

        self.strategy = strategy
        return strategy

    def init(self, context):
        """Build the network, wrap the optimizer with fleet and pick the next pass.

        Servers go to ``server_pass``. Workers record what to fetch during
        training (metric variables, their aliases and the print period) and
        go to ``startup_pass``.
        """
        self.model.train_net()
        optimizer = self.model.optimizer()
        strategy = self.build_strategy()
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(self.model.get_cost_op())

        if fleet.is_server():
            context['status'] = 'server_pass'
            return

        self.fetch_vars = []
        self.fetch_alias = []
        self.fetch_period = self.model.get_fetch_period()

        metrics = self.model.get_metrics()
        if metrics:
            # Materialise as lists: on Python 3 these are dict views, while
            # they are handed to train_from_dataset as fetch_list/fetch_info.
            self.fetch_vars = list(metrics.values())
            self.fetch_alias = list(metrics.keys())
        context['status'] = 'startup_pass'

    def server(self, context):
        """Initialise and run the fleet server, then mark the run as finished."""
        fleet.init_server()
        fleet.run_server()
        context['is_exit'] = True

    def startup(self, context):
        """Run the fleet startup program and move on to ``train_pass``."""
        self._exe.run(fleet.startup_program)
        context['status'] = 'train_pass'

    def dataloader_train(self, context):
        """Fallback train pass: report that it is unsupported and terminate."""
        print("online learning can only support LINUX only")
        context['status'] = 'terminal_pass'

    def _get_dataset(self, state="TRAIN", hour=None):
        """Build a fluid dataset over the files of one data directory.

        Args:
            state: ``"TRAIN"`` reads the ``train.reader`` settings and model
                inputs; any other value reads ``evaluate.reader`` and the
                inference inputs.
            hour: optional sub-directory appended to the data path
                (``dataset_train`` passes ``'YYYYMMDD/HH'``). ``None`` uses
                the data path itself.

        Returns:
            The configured dataset. The file list is also kept on
            ``self.files``.
        """
        if state == "TRAIN":
            inputs = self.model.get_inputs()
            namespace = "train.reader"
            data_path = envs.get_global_env(
                "train_data_path", None, namespace)
        else:
            inputs = self.model.get_infer_inputs()
            namespace = "evaluate.reader"
            data_path = envs.get_global_env(
                "test_data_path", None, namespace)

        threads = int(envs.get_runtime_environ("train.trainer.threads"))
        batch_size = envs.get_global_env("batch_size", None, namespace)
        reader_class = envs.get_global_env("class", None, namespace)

        # The dataset feeds each file through this reader script.
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
        pipe_cmd = "python {} {} {} {}".format(
            reader, reader_class, state, self._config_yaml)

        # "paddlerec::<relative path>" is resolved against PACKAGE_BASE.
        if data_path.startswith("paddlerec::"):
            package_base = envs.get_runtime_environ("PACKAGE_BASE")
            assert package_base is not None
            data_path = os.path.join(
                package_base, data_path.split("::")[1])

        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_use_var(inputs)
        dataset.set_pipe_command(pipe_cmd)
        dataset.set_batch_size(batch_size)
        dataset.set_thread(threads)

        if hour is not None:
            data_path = os.path.join(data_path, hour)

        # os.listdir order is arbitrary; sort so the file list is
        # reproducible between runs.
        self.files = [
            os.path.join(data_path, name)
            for name in sorted(os.listdir(data_path))
        ]
        dataset.set_filelist(self.files)
        return dataset

    def dataset_train(self, context):
        """Train on every hourly partition for ``train.days`` days, then infer.

        For each hour the matching ``YYYYMMDD/HH`` directory is loaded,
        trained on once, and the model is saved under the id ``YYYYMMDD_HH``.

        Raises:
            ValueError: if the start day is missing or is not in
                ``YYYYMMDD`` form.
        """
        fleet.init_worker()

        days = int(envs.get_global_env("train.days"))

        # NOTE(review): the original parsed the literal string "begin_day_d",
        # which can never match '%Y%m%d'. The key "train.begin_day" is assumed
        # by analogy with "train.days" -- confirm against the config schema.
        begin_day_cfg = envs.get_global_env("train.begin_day")
        if begin_day_cfg is None:
            raise ValueError(
                "train.begin_day must be set to a YYYYMMDD date")
        # str(): a YAML loader may hand back 20200518 as an int.
        begin_day = datetime.datetime.strptime(str(begin_day_cfg), '%Y%m%d')

        for day_offset in range(days):
            for hour in range(24):
                # Keep the loop counters separate from the computed moment;
                # reusing one name broke timedelta() on the second hour.
                moment = begin_day + datetime.timedelta(
                    days=day_offset, hours=hour)
                hour_dir = moment.strftime('%Y%m%d/%H')
                epoch_id = moment.strftime('%Y%m%d_%H')

                dataset = self._get_dataset(hour=hour_dir)
                ins = self._get_dataset_ins()

                begin_time = time.time()
                self._exe.train_from_dataset(
                    program=fluid.default_main_program(),
                    dataset=dataset,
                    fetch_list=self.fetch_vars,
                    fetch_info=self.fetch_alias,
                    print_period=self.fetch_period)
                end_time = time.time()

                times = end_time - begin_time
                # Guard against a zero elapsed time on coarse clocks.
                speed = ins / times if times > 0 else 0.0
                print("epoch {} using time {}, speed {:.2f} lines/s".format(
                    epoch_id, times, speed))
                self.save(epoch_id, "train", is_fleet=True)

        fleet.stop_worker()
        context['status'] = 'infer_pass'

    def terminal(self, context):
        """Print every saved model (epoch id and directory) and finish the run."""
        for model in self.increment_models:
            print("epoch :{}, dir: {}".format(model[0], model[1]))
        context['is_exit'] = True