From 1e95361796e5e411c2f33da1298fac93e5f2fe9b Mon Sep 17 00:00:00 2001
From: tangwei
Date: Tue, 31 Mar 2020 20:15:49 +0800
Subject: [PATCH] add ctr-dnn example

---
 examples/ctr-dnn_train.yaml                   |  53 +++++
 models/ctr_dnn/hyper_parameters.yaml          |  29 ++-
 models/ctr_dnn/model.py                       |  16 +-
 trainer/cluster_train_local.py                |   0
 ...ng_offline.py => cluster_train_offline.py} |   0
 trainer/cluster_training_local.py             |  77 --------
 trainer/single_train.py                       | 169 ++++++++++++++++++
 trainer/trainer.py                            |  20 ++-
 utils/envs.py                                 |   8 +-
 9 files changed, 278 insertions(+), 94 deletions(-)
 create mode 100644 examples/ctr-dnn_train.yaml
 create mode 100644 trainer/cluster_train_local.py
 rename trainer/{cluster_training_offline.py => cluster_train_offline.py} (100%)
 delete mode 100644 trainer/cluster_training_local.py
 create mode 100644 trainer/single_train.py

diff --git a/examples/ctr-dnn_train.yaml b/examples/ctr-dnn_train.yaml
new file mode 100644
index 00000000..302fc110
--- /dev/null
+++ b/examples/ctr-dnn_train.yaml
@@ -0,0 +1,53 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+train:
+  batch_size: 32
+  threads: 12
+  epochs: 10
+  trainer: "SingleTraining"
+
+  reader:
+    mode: "dataset"
+    pipe_command: "python reader.py dataset"
+    train_data_path: "raw_data"
+
+  model:
+    models: "eleps.models.ctr_dnn.model"
+    hyper_parameters:
+      sparse_inputs_slots: 27
+      sparse_feature_number: 1000001
+      sparse_feature_dim: 8
+      dense_input_dim: 13
+      fc_sizes: [1024, 512, 32]
+      learning_rate: 0.001
+
+  save:
+    increment:
+      dirname: "models_for_increment"
+      epoch_interval: 2
+      save_last: True
+    inference:
+      dirname: "models_for_inference"
+      epoch_interval: 4
+      feed_varnames: ["C1", "C2", "C3"]
+      fetch_varnames: ["predict"]
+      save_last: True
+
+evaluate:
+  batch_size: 32
+  train_thread_num: 12
+  reader: "reader.py"
+
+
diff --git a/models/ctr_dnn/hyper_parameters.yaml b/models/ctr_dnn/hyper_parameters.yaml
index bd932ca9..10c24807 100644
--- a/models/ctr_dnn/hyper_parameters.yaml
+++ b/models/ctr_dnn/hyper_parameters.yaml
@@ -1,8 +1,21 @@
-{
-    "sparse_inputs_slots": 27,
-    "sparse_feature_number": 1000001,
-    "sparse_feature_dim": 8,
-    "dense_input_dim": 13,
-    "fc_sizes": [400, 400, 40],
-    "learning_rate": 0.001
-}
\ No newline at end of file
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+sparse_inputs_slots: 27
+sparse_feature_number: 1000001
+sparse_feature_dim: 8
+dense_input_dim: 13
+fc_sizes: [400, 400, 40]
+learning_rate: 0.001
diff --git a/models/ctr_dnn/model.py b/models/ctr_dnn/model.py
index 35d1a52a..cc3242d9 100644
--- a/models/ctr_dnn/model.py
+++ b/models/ctr_dnn/model.py
@@ -57,6 +57,12 @@ class Train(object):
         self.dense_input, self.dense_input_varname = dense_input()
         self.label_input, self.label_input_varname = label_input()
 
+    def input_vars(self):
+        return self.sparse_inputs + [self.dense_input] + [self.label_input]
+
+    def input_varnames(self):
+        return [var.name for var in self.input_vars()]
+
     def net(self):
         def embedding_layer(input):
             sparse_feature_number = envs.get_global_env("sparse_feature_number")
@@ -101,22 +107,28 @@
 
         self.predict = predict
 
-    def loss(self, predict):
+    def avg_loss(self, predict):
         cost = fluid.layers.cross_entropy(input=predict, label=self.label_input)
         avg_cost = fluid.layers.reduce_sum(cost)
         self.loss = avg_cost
+        return avg_cost
 
-    def metric(self):
+    def metrics(self):
         auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
                                              label=self.label_input,
                                              num_thresholds=2 ** 12,
                                              slide_steps=20)
+        return (auc, batch_auc)
 
     def optimizer(self):
         learning_rate = envs.get_global_env("learning_rate")
         optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True)
         return optimizer
 
+    def optimize(self):
+        optimizer = self.optimizer()
+        optimizer.minimize(self.loss)
+
 
 class Evaluate(object):
     def input(self):
diff --git a/trainer/cluster_train_local.py b/trainer/cluster_train_local.py
new file mode 100644
index 00000000..e69de29b
diff --git a/trainer/cluster_training_offline.py b/trainer/cluster_train_offline.py
similarity index 100%
rename from trainer/cluster_training_offline.py
rename to trainer/cluster_train_offline.py
diff --git a/trainer/cluster_training_local.py b/trainer/cluster_training_local.py
deleted file mode 100644
index f90d7b18..00000000
--- a/trainer/cluster_training_local.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-import os
-import time
-import numpy as np
-import logging
-import paddle.fluid as fluid
-from network import CTR
-from argument import params_args
-
-logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger("fluid")
-logger.setLevel(logging.INFO)
-
-
-def get_dataset(inputs, params):
-    dataset = fluid.DatasetFactory().create_dataset()
-    dataset.set_use_var(inputs)
-    dataset.set_pipe_command("python dataset_generator.py")
-    dataset.set_batch_size(params.batch_size)
-    dataset.set_thread(int(params.cpu_num))
-    file_list = [
-        str(params.train_files_path) + "/%s" % x
-        for x in os.listdir(params.train_files_path)
-    ]
-    dataset.set_filelist(file_list)
-    logger.info("file list: {}".format(file_list))
-    return dataset
-
-
-def train(params):
-    ctr_model = CTR()
-    inputs = ctr_model.input_data(params)
-    avg_cost, auc_var, batch_auc_var = ctr_model.net(inputs, params)
-    optimizer = fluid.optimizer.Adam(params.learning_rate)
-    optimizer.minimize(avg_cost)
-    fluid.default_main_program()
-    exe = fluid.Executor(fluid.CPUPlace())
-    exe.run(fluid.default_startup_program())
-    dataset = get_dataset(inputs, params)
-
-    logger.info("Training Begin")
-    for epoch in range(params.epochs):
-        start_time = time.time()
-        exe.train_from_dataset(program=fluid.default_main_program(),
-                               dataset=dataset,
-                               fetch_list=[auc_var],
-                               fetch_info=["Epoch {} auc ".format(epoch)],
-                               print_period=100,
-                               debug=False)
-        end_time = time.time()
-        logger.info("epoch %d finished, use time=%d\n" %
-                    ((epoch), end_time - start_time))
-
-        if params.test:
-            model_path = (str(params.model_path) + "/" + "epoch_" + str(epoch))
-            fluid.io.save_persistables(executor=exe, dirname=model_path)
-
-    logger.info("Train Success!")
-
-
-if __name__ == "__main__":
-    params = params_args()
-    train(params)
\ No newline at end of file
diff --git a/trainer/single_train.py b/trainer/single_train.py
new file mode 100644
index 00000000..20804543
--- /dev/null
+++ b/trainer/single_train.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Training with fluid on a single node only.
+""" + +from __future__ import print_function +import os +import time +import numpy as np +import logging +import paddle.fluid as fluid + +from .trainer import Trainer +from ..utils import envs + +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger("fluid") +logger.setLevel(logging.INFO) + + +def need_save(epoch_id, epoch_interval, is_last=False): + if is_last: + return True + + return epoch_id % epoch_interval == 0 + + +class SingleTrainer(Trainer): + + def __init__(self, config=None, yaml_file=None): + Trainer.__init__(self, config, yaml_file) + + self.exe = fluid.Executor(fluid.CPUPlace()) + + self.regist_context_processor('uninit', self.instance) + self.regist_context_processor('init_pass', self.init) + self.regist_context_processor('train_pass', self.train) + self.regist_context_processor('infer_pass', self.infer) + self.regist_context_processor('terminal_pass', self.terminal) + + def instance(self, context): + model_package = __import__(envs.get_global_env("train.model.models")) + train_model = getattr(model_package, 'Train') + + self.model = train_model() + + context['status'] = 'init_pass' + + def init(self, context): + self.model.input() + self.model.net() + self.model.loss() + self.metrics = self.model.metrics() + self.model.optimize() + + # run startup program at once + self.exe.run(fluid.default_startup_program()) + + context['status'] = 'train_pass' + + def train(self, context): + print("Need to be implement") + context['is_exit'] = True + + def infer(self, context): + print("Need to be implement") + context['is_exit'] = True + + def terminal(self, context): + context['is_exit'] = True + + +class SingleTrainerWithDataloader(SingleTrainer): + pass + + +class SingleTrainerWithDataset(SingleTrainer): + def _get_dataset(self, inputs, threads, batch_size, pipe_command, train_files_path): + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_use_var(inputs) + dataset.set_pipe_command(pipe_command) + dataset.set_batch_size(batch_size) + dataset.set_thread(threads) + file_list = [ + os.path.join(train_files_path, x) + for x in os.listdir(train_files_path) + ] + + dataset.set_filelist(file_list) + return dataset + + def save(self, epoch_id): + def save_inference_model(): + is_save_inference = envs.get_global_env("save.inference", False) + if not is_save_inference: + return + + save_interval = envs.get_global_env("save.inference.epoch_interval", 1) + if not need_save(epoch_id, save_interval, False): + return + + feed_varnames = envs.get_global_env("save.inference.feed_varnames", None) + fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None) + fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames] + dirname = envs.get_global_env("save.inference.dirname", None) + + assert dirname is not None + dirname = os.path.join(dirname, str(epoch_id)) + fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) + + def save_persistables(): + is_save_increment = envs.get_global_env("save.increment", False) + if not is_save_increment: + return + + save_interval = envs.get_global_env("save.increment.epoch_interval", 1) + if not need_save(epoch_id, save_interval, False): + return + + dirname = envs.get_global_env("save.inference.dirname", None) + + assert dirname is not None + dirname = os.path.join(dirname, str(epoch_id)) + fluid.io.save_persistables(self.exe, dirname) + + is_save = envs.get_global_env("save", False) + + if not is_save: + return + + save_persistables() + 
save_inference_model() + + def train(self, context): + inputs = self.model.input_vars() + threads = envs.get_global_env("threads") + batch_size = envs.get_global_env("batch_size") + pipe_command = envs.get_global_env("pipe_command") + train_data_path = envs.get_global_env("train_data_path") + + dataset = self._get_dataset(inputs, threads, batch_size, pipe_command, train_data_path) + + epochs = envs.get_global_env("epochs") + + for i in range(epochs): + self.exe.train_from_dataset(program=fluid.default_main_program(), + dataset=dataset, + fetch_list=[self.metrics], + fetch_info=["epoch {} auc ".format(i)], + print_period=100) + context['status'] = 'infer_pass' + + +def infer(self, context): + context['status'] = 'terminal_pass' diff --git a/trainer/trainer.py b/trainer/trainer.py index e56ed0d7..cd8f971c 100755 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -14,6 +14,9 @@ import abc import time +import yaml + +from .. utils import envs class Trainer(object): @@ -21,9 +24,20 @@ class Trainer(object): """ __metaclass__ = abc.ABCMeta - def __init__(self, config): - """R - """ + def __init__(self, config=None, yaml_file=None): + + if not config and not yaml_file: + raise ValueError("config and yaml file have at least one not empty") + + if config and yaml_file: + print("config and yaml file are all assigned, will use yaml file: {}".format(yaml_file)) + + if yaml_file: + with open(yaml_file, "r") as rb: + config = yaml.load(rb.read()) + + envs.set_global_envs(config) + self._status_processor = {} self._context = {'status': 'uninit', 'is_exit': False} diff --git a/utils/envs.py b/utils/envs.py index 51e5b641..7bdfb988 100644 --- a/utils/envs.py +++ b/utils/envs.py @@ -24,17 +24,17 @@ def decode_value(v): return v -def set_global_envs(yaml, envs): +def set_global_envs(yaml): for k, v in yaml.items(): - envs[k] = encode_value(v) + os.environ[k] = encode_value(v) -def get_global_env(env_name): +def get_global_env(env_name, default_value=None): """ get os environment value """ if env_name not in os.environ: - raise ValueError("can not find config of {}".format(env_name)) + return default_value v = os.environ[env_name] return decode_value(v) -- GitLab
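
For reference, a minimal sketch (not part of the patch) of how the new SingleTrainerWithDataset might be driven. The patch registers context processors but ships no entry point, so the loop below is inferred from the `_status_processor` and `_context` fields added to trainer.py; the `eleps` package root in the import path is an assumption taken from the example config.

    # hypothetical driver, assuming the repository root is importable as `eleps`
    from eleps.trainer.single_train import SingleTrainerWithDataset

    trainer = SingleTrainerWithDataset(yaml_file="examples/ctr-dnn_train.yaml")

    # Walk the state machine registered in SingleTrainer.__init__:
    # uninit -> init_pass -> train_pass -> infer_pass -> terminal_pass.
    # Assumes regist_context_processor() keeps a status -> handler mapping
    # in self._status_processor, as its call sites suggest.
    context = trainer._context
    while not context['is_exit']:
        handler = trainer._status_processor[context['status']]
        handler(context)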