diff --git a/example/wide_and_deep/README.md b/example/wide_and_deep/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f770297dd01faadef848a6457ce3ffef80526f5c
--- /dev/null
+++ b/example/wide_and_deep/README.md
@@ -0,0 +1,93 @@
+# Wide&Deep Recommendation Model
+## Overview
+This is an implementation of the Wide&Deep model as described in the [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf) paper.
+
+The Wide&Deep model jointly trains a wide linear model and a deep neural network, combining the benefits of memorization and generalization for recommender systems.
+
+## Dataset
+The [Criteo dataset](http://labs.criteo.com/2014/02/download-kaggle-display-advertising-challenge-dataset/) is used for model training and evaluation.
+
+## Running Code
+
+### Download and preprocess dataset
+To download and preprocess the dataset, please install the Pandas package first, then issue the following command:
+```
+bash download.sh
+```
+
+### Code Structure
+The code structure is as follows:
+```
+|--- wide_and_deep/
+    train_and_test.py          "Entrance of Wide&Deep model training and evaluation"
+    test.py                    "Entrance of Wide&Deep model evaluation"
+    train.py                   "Entrance of Wide&Deep model training"
+    train_and_test_multinpu.py "Entrance of Wide&Deep model data-parallel training and evaluation"
+    |--- src/                  "Model, dataset and configuration code"
+        config.py              "Parameter configuration"
+        datasets.py            "Dataset loader"
+        wide_and_deep.py       "Model structure"
+        callbacks.py           "Callback classes for training and evaluation"
+        metrics.py             "AUC metric class"
+        process_data.py        "Criteo raw data preprocessing"
+```
+
+### Train and evaluate model
+To train and evaluate the model, issue the following command:
+```
+python train_and_test.py
+```
+Arguments:
+  * `--data_path`: The directory that holds the preprocessed dataset (the output directory of the download and preprocessing step).
+  * `--epochs`: Total number of training epochs.
+  * `--batch_size`: Training batch size.
+  * `--eval_batch_size`: Evaluation batch size.
+  * `--field_size`: The number of feature fields per sample.
+  * `--vocab_size`: The total number of feature IDs in the dataset.
+  * `--emb_dim`: The dense embedding dimension of the sparse features.
+  * `--deep_layer_dim`: The dimensions of the deep layers.
+  * `--deep_layer_act`: The activation function used in the deep layers.
+  * `--keep_prob`: The keep probability of the dropout layer.
+  * `--ckpt_path`: The location of the checkpoint file.
+  * `--eval_file_name`: The file to which evaluation results are written.
+  * `--loss_file_name`: The file to which training loss values are written.
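+
+For reference, a full invocation with the default values from `src/config.py` might look like the following (the data path is only a placeholder for wherever the preprocessed dataset was written):
+```
+python train_and_test.py --data_path=./test_raw_data/ --epochs=15 --batch_size=10000 \
+    --deep_layer_dim 1024 512 256 128 --deep_layer_act=relu --ckpt_path=./checkpoints/
+```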
+
+To train the model, issue the following command:
+```
+python train.py
+```
+Arguments:
+  * `--data_path`: The directory that holds the preprocessed dataset (the output directory of the download and preprocessing step).
+  * `--epochs`: Total number of training epochs.
+  * `--batch_size`: Training batch size.
+  * `--eval_batch_size`: Evaluation batch size.
+  * `--field_size`: The number of feature fields per sample.
+  * `--vocab_size`: The total number of feature IDs in the dataset.
+  * `--emb_dim`: The dense embedding dimension of the sparse features.
+  * `--deep_layer_dim`: The dimensions of the deep layers.
+  * `--deep_layer_act`: The activation function used in the deep layers.
+  * `--keep_prob`: The keep probability of the dropout layer.
+  * `--ckpt_path`: The location of the checkpoint file.
+  * `--eval_file_name`: The file to which evaluation results are written.
+  * `--loss_file_name`: The file to which training loss values are written.
+
+To evaluate the model, issue the following command:
+```
+python test.py
+```
+Arguments:
+  * `--data_path`: The directory that holds the preprocessed dataset (the output directory of the download and preprocessing step).
+  * `--epochs`: Total number of training epochs.
+  * `--batch_size`: Training batch size.
+  * `--eval_batch_size`: Evaluation batch size.
+  * `--field_size`: The number of feature fields per sample.
+  * `--vocab_size`: The total number of feature IDs in the dataset.
+  * `--emb_dim`: The dense embedding dimension of the sparse features.
+  * `--deep_layer_dim`: The dimensions of the deep layers.
+  * `--deep_layer_act`: The activation function used in the deep layers.
+  * `--keep_prob`: The keep probability of the dropout layer.
+  * `--ckpt_path`: The location of the checkpoint file.
+  * `--eval_file_name`: The file to which evaluation results are written.
+  * `--loss_file_name`: The file to which training loss values are written.
+
+There are other arguments related to the model and the training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
+
diff --git a/example/wide_and_deep/src/__init__.py b/example/wide_and_deep/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/example/wide_and_deep/src/callbacks.py b/example/wide_and_deep/src/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ceefc115a646dc4ca1ab019335695c2ab5d30b6
--- /dev/null
+++ b/example/wide_and_deep/src/callbacks.py
@@ -0,0 +1,104 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+callbacks
+"""
+import time
+from mindspore.train.callback import Callback
+from mindspore import context
+
+def add_write(file_path, out_str):
+    """
+    Append a line to the given file.
+    """
+    with open(file_path, 'a+', encoding="utf-8") as file_out:
+        file_out.write(out_str + "\n")
+
+
+class LossCallBack(Callback):
+    """
+    Monitor the loss during training and optionally write it to a log file.
+
+    Note:
+        If per_print_times is 0, the loss is not logged.
+
+    Args:
+        per_print_times (int): Log the loss every `per_print_times` steps. Default: 1.
+    """
+    def __init__(self, config=None, per_print_times=1):
+        super(LossCallBack, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be int and >= 0.")
+        self._per_print_times = per_print_times
+        self.config = config
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        cur_num = cb_params.cur_step_num
+        print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)
+
+        if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
+            with open(self.config.loss_file_name, "a+") as loss_file:
+                loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s\n" %
+                                (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
+            print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
+                  (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
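+
+
+# Usage sketch (illustrative only, mirroring what train.py does): LossCallBack is
+# passed to Model.train together with a config object that provides loss_file_name,
+# e.g.
+#     model = Model(train_net)
+#     model.train(epochs, ds_train, callbacks=[LossCallBack(config=cfg)])
+# where cfg is a WideDeepConfig instance from src/config.py.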
+
+
+class EvalCallBack(Callback):
+    """
+    Run evaluation at the end of every training epoch and record the AUC result.
+
+    Note:
+        If print_per_step is 0, do not print the evaluation result.
+
+    Args:
+        print_per_step (int): Print interval in steps. Default: 1.
+    """
+    def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
+        super(EvalCallBack, self).__init__()
+        if not isinstance(print_per_step, int) or print_per_step < 0:
+            raise ValueError("print_per_step must be int and >= 0.")
+        self.print_per_step = print_per_step
+        self.model = model
+        self.eval_dataset = eval_dataset
+        self.aucMetric = auc_metric
+        self.aucMetric.clear()
+        self.eval_file_name = config.eval_file_name
+
+    def epoch_end(self, run_context):
+        """
+        Evaluate the model at the end of each epoch and log the result.
+        """
+        self.aucMetric.clear()
+        context.set_auto_parallel_context(strategy_ckpt_save_file="",
+                                          strategy_ckpt_load_file="./strategy_train.ckpt")
+        start_time = time.time()
+        out = self.model.eval(self.eval_dataset)
+        end_time = time.time()
+        eval_time = int(end_time - start_time)
+
+        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
+        print(out_str)
+        add_write(self.eval_file_name, out_str)
diff --git a/example/wide_and_deep/src/config.py b/example/wide_and_deep/src/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..707750c97a4a97375a0add31536fcddbd0a1670c
--- /dev/null
+++ b/example/wide_and_deep/src/config.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" config.
""" +import argparse + + +def argparse_init(): + """ + argparse_init + """ + parser = argparse.ArgumentParser(description='WideDeep') + parser.add_argument("--data_path", type=str, default="./test_raw_data/") + parser.add_argument("--epochs", type=int, default=15) + parser.add_argument("--batch_size", type=int, default=10000) + parser.add_argument("--eval_batch_size", type=int, default=15) + parser.add_argument("--field_size", type=int, default=39) + parser.add_argument("--vocab_size", type=int, default=184965) + parser.add_argument("--emb_dim", type=int, default=80) + parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) + parser.add_argument("--deep_layer_act", type=str, default='relu') + parser.add_argument("--keep_prob", type=float, default=1.0) + + parser.add_argument("--output_path", type=str, default="./output/") + parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") + parser.add_argument("--eval_file_name", type=str, default="eval.log") + parser.add_argument("--loss_file_name", type=str, default="loss.log") + return parser + + +class WideDeepConfig(): + """ + WideDeepConfig + """ + def __init__(self): + self.data_path = "./test_raw_data/" + self.epochs = 15 + self.batch_size = 10000 + self.eval_batch_size = 10000 + self.field_size = 39 + self.vocab_size = 184965 + self.emb_dim = 80 + self.deep_layer_dim = [1024, 512, 256, 128] + self.deep_layer_act = 'relu' + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = False + self.keep_prob = 1.0 + self.l2_coef = 8e-5 + + self.output_path = "./output" + self.eval_file_name = "eval.log" + self.loss_file_name = "loss.log" + self.ckpt_path = "./checkpoints/" + + def argparse_init(self): + """ + argparse_init + """ + parser = argparse_init() + args, _ = parser.parse_known_args() + self.data_path = args.data_path + self.epochs = args.epochs + self.batch_size = args.batch_size + self.eval_batch_size = args.eval_batch_size + self.field_size = args.field_size + self.vocab_size = args.vocab_size + self.emb_dim = args.emb_dim + self.deep_layer_dim = args.deep_layer_dim + self.deep_layer_act = args.deep_layer_act + self.keep_prob = args.keep_prob + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = False + self.l2_coef = 8e-5 + + self.output_path = args.output_path + self.eval_file_name = args.eval_file_name + self.loss_file_name = args.loss_file_name + self.ckpt_path = args.ckpt_path diff --git a/example/wide_and_deep/src/datasets.py b/example/wide_and_deep/src/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..775dd7ca54b5d12ab6b76c3157dd8952c222e61f --- /dev/null +++ b/example/wide_and_deep/src/datasets.py @@ -0,0 +1,207 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""train_imagenet.""" + + +import os +import math +import numpy as np +import pandas as pd +import mindspore.dataset.engine as de +import mindspore.common.dtype as mstype + + +class H5Dataset(): + """ + H5DataSet + """ + input_length = 39 + + def __init__(self, data_path, train_mode=True, train_num_of_parts=21, + test_num_of_parts=3): + self._hdf_data_dir = data_path + self._is_training = train_mode + + if self._is_training: + self._file_prefix = 'train' + self._num_of_parts = train_num_of_parts + else: + self._file_prefix = 'test' + self._num_of_parts = test_num_of_parts + + self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, + self._num_of_parts) + print("data_size: {}".format(self.data_size)) + + def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts): + size = 0 + for part in range(num_of_parts): + _y = pd.read_hdf(os.path.join(hdf_data_dir, + file_prefix + '_output_part_' + str( + part) + '.h5')) + size += _y.shape[0] + return size + + def _iterate_hdf_files_(self, num_of_parts=None, + shuffle_block=False): + """ + iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts + from the beginning, thus the data stream will never stop + :param train_mode: True or false,false is eval_mode, + this file iterator will go through the train set + :param num_of_parts: number of files + :param shuffle_block: shuffle block files at every round + :return: input_hdf_file_name, output_hdf_file_name, finish_flag + """ + parts = np.arange(num_of_parts) + while True: + if shuffle_block: + for _ in range(int(shuffle_block)): + np.random.shuffle(parts) + for i, p in enumerate(parts): + yield os.path.join(self._hdf_data_dir, + self._file_prefix + '_input_part_' + str( + p) + '.h5'), \ + os.path.join(self._hdf_data_dir, + self._file_prefix + '_output_part_' + str( + p) + '.h5'), i + 1 == len(parts) + + def _generator(self, X, y, batch_size, shuffle=True): + """ + should be accessed only in private + :param X: + :param y: + :param batch_size: + :param shuffle: + :return: + """ + number_of_batches = np.ceil(1. 
* X.shape[0] / batch_size) + counter = 0 + finished = False + sample_index = np.arange(X.shape[0]) + if shuffle: + for _ in range(int(shuffle)): + np.random.shuffle(sample_index) + assert X.shape[0] > 0 + while True: + batch_index = sample_index[ + batch_size * counter: batch_size * (counter + 1)] + X_batch = X[batch_index] + y_batch = y[batch_index] + counter += 1 + yield X_batch, y_batch, finished + if counter == number_of_batches: + counter = 0 + finished = True + + def batch_generator(self, batch_size=1000, + random_sample=False, shuffle_block=False): + """ + :param train_mode: True or false,false is eval_mode, + :param batch_size + :param num_of_parts: number of files + :param random_sample: if True, will shuffle + :param shuffle_block: shuffle file blocks at every round + :return: + """ + + for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts, + shuffle_block): + start = stop = None + X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values + y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values + data_gen = self._generator(X_all, y_all, batch_size, + shuffle=random_sample) + finished = False + + while not finished: + X, y, finished = data_gen.__next__() + X_id = X[:, 0:self.input_length] + X_va = X[:, self.input_length:] + yield np.array(X_id.astype(dtype=np.int32)), np.array( + X_va.astype(dtype=np.float32)), np.array( + y.astype(dtype=np.float32)) + + +def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000): + """ + get_h5_dataset + """ + data_para = { + 'batch_size': batch_size, + } + if train_mode: + data_para['random_sample'] = True + data_para['shuffle_block'] = True + + h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode) + numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size) + + def _iter_h5_data(): + train_eval_gen = h5_dataset.batch_generator(**data_para) + for _ in range(0, numbers_of_batch, 1): + yield train_eval_gen.__next__() + + ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"]) + ds.set_dataset_size(numbers_of_batch) + ds = ds.repeat(epochs) + return ds + + +def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000, + line_per_sample=1000, rank_size=None, rank_id=None): + """ + get_tf_dataset + """ + dataset_files = [] + file_prefix_name = 'train' if train_mode else 'test' + shuffle = train_mode + for (dirpath, _, filenames) in os.walk(data_dir): + for filename in filenames: + if file_prefix_name in filename and "tfrecord" in filename: + dataset_files.append(os.path.join(dirpath, filename)) + schema = de.Schema() + schema.add_column('feat_ids', de_type=mstype.int32) + schema.add_column('feat_vals', de_type=mstype.float32) + schema.add_column('label', de_type=mstype.float32) + if rank_size is not None and rank_id is not None: + ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8, + num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True) + else: + ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8) + ds = ds.batch(int(batch_size / line_per_sample), + drop_remainder=True) + ds = ds.map(operations=(lambda x, y, z: ( + np.array(x).flatten().reshape(batch_size, 39), + np.array(y).flatten().reshape(batch_size, 39), + np.array(z).flatten().reshape(batch_size, 1))), + input_columns=['feat_ids', 'feat_vals', 'label'], + columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8) + #if train_mode: + ds = ds.repeat(epochs) + return ds + + +def 
create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000, + is_tf_dataset=True, line_per_sample=1000, rank_size=None, rank_id=None): + """ + create_dataset + """ + if is_tf_dataset: + return _get_tf_dataset(data_dir, train_mode, epochs, batch_size, + line_per_sample, rank_size=rank_size, rank_id=rank_id) + if rank_size > 1: + raise RuntimeError("please use tfrecord dataset.") + return _get_h5_dataset(data_dir, train_mode, epochs, batch_size) diff --git a/example/wide_and_deep/src/metrics.py b/example/wide_and_deep/src/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..277d6744dc92433375413fdae25426ec5245c296 --- /dev/null +++ b/example/wide_and_deep/src/metrics.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Area under cure metric +""" + +from mindspore.nn.metrics import Metric +from sklearn.metrics import roc_auc_score + +class AUCMetric(Metric): + """ + Area under cure metric + """ + + def __init__(self): + super(AUCMetric, self).__init__() + self.clear() + + def clear(self): + """Clear the internal evaluation result.""" + self.true_labels = [] + self.pred_probs = [] + + def update(self, *inputs): # inputs + all_predict = inputs[1].asnumpy() # predict + all_label = inputs[2].asnumpy() # label + self.true_labels.extend(all_label.flatten().tolist()) + self.pred_probs.extend(all_predict.flatten().tolist()) + + def eval(self): + if len(self.true_labels) != len(self.pred_probs): + raise RuntimeError( + 'true_labels.size is not equal to pred_probs.size()') + + auc = roc_auc_score(self.true_labels, self.pred_probs) + print("====" * 20 + " auc_metric end") + print("====" * 20 + " auc: {}".format(auc)) + return auc diff --git a/example/wide_and_deep/src/process_data.py b/example/wide_and_deep/src/process_data.py new file mode 100644 index 0000000000000000000000000000000000000000..37b38b0bbbbed76530318a7693eabbf1e2dcd366 --- /dev/null +++ b/example/wide_and_deep/src/process_data.py @@ -0,0 +1,268 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Criteo data process +""" + +import os +import pickle +import collections +import argparse + +import numpy as np +import pandas as pd + +TRAIN_LINE_COUNT = 45840617 +TEST_LINE_COUNT = 6042135 + +class CriteoStatsDict(): + """create data dict""" + def __init__(self): + self.field_size = 39 # value_1-13; cat_1-26; + self.val_cols = ["val_{}".format(i+1) for i in range(13)] + self.cat_cols = ["cat_{}".format(i+1) for i in range(26)] + # + self.val_min_dict = {col: 0 for col in self.val_cols} + self.val_max_dict = {col: 0 for col in self.val_cols} + self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols} + # + self.oov_prefix = "OOV_" + + self.cat2id_dict = {} + self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)}) + self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)}) + # + def stats_vals(self, val_list): + """vals status""" + assert len(val_list) == len(self.val_cols) + def map_max_min(i, val): + key = self.val_cols[i] + if val != "": + if float(val) > self.val_max_dict[key]: + self.val_max_dict[key] = float(val) + if float(val) < self.val_min_dict[key]: + self.val_min_dict[key] = float(val) + # + for i, val in enumerate(val_list): + map_max_min(i, val) + # + def stats_cats(self, cat_list): + assert len(cat_list) == len(self.cat_cols) + def map_cat_count(i, cat): + key = self.cat_cols[i] + self.cat_count_dict[key][cat] += 1 + # + for i, cat in enumerate(cat_list): + map_cat_count(i, cat) + # + def save_dict(self, output_path, prefix=""): + with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.val_max_dict, file_wrt) + with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.val_min_dict, file_wrt) + with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.cat_count_dict, file_wrt) + # + def load_dict(self, dict_path, prefix=""): + with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt: + self.val_max_dict = pickle.load(file_wrt) + with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt: + self.val_min_dict = pickle.load(file_wrt) + with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt: + self.cat_count_dict = pickle.load(file_wrt) + print("val_max_dict.items()[:50]: {}".format(list(self.val_max_dict.items()))) + print("val_min_dict.items()[:50]: {}".format(list(self.val_min_dict.items()))) + # + # + def get_cat2id(self, threshold=100): + """get cat to id""" + # before_all_count = 0 + # after_all_count = 0 + for key, cat_count_d in self.cat_count_dict.items(): + new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items())) + for cat_str, _ in new_cat_count_d.items(): + self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict) + # print("before_all_count: {}".format(before_all_count)) # before_all_count: 33762577 + # print("after_all_count: {}".format(after_all_count)) # after_all_count: 184926 + print("cat2id_dict.size: {}".format(len(self.cat2id_dict))) + print("cat2id_dict.items()[:50]: {}".format(self.cat2id_dict.items()[:50])) + # + def map_cat2id(self, values, cats): + """map cat to id""" + def minmax_sclae_value(i, val): + # min_v = float(self.val_min_dict["val_{}".format(i+1)]) + max_v = 
float(self.val_max_dict["val_{}".format(i + 1)]) + # return (float(val) - min_v) * 1.0 / (max_v - min_v) + return float(val) * 1.0 / max_v + + id_list = [] + weight_list = [] + for i, val in enumerate(values): + if val == "": + id_list.append(i) + weight_list.append(0) + else: + key = "val_{}".format(i + 1) + id_list.append(self.cat2id_dict[key]) + weight_list.append(minmax_sclae_value(i, float(val))) + # + for i, cat_str in enumerate(cats): + key = "cat_{}".format(i + 1) + "_" + cat_str + if key in self.cat2id_dict: + id_list.append(self.cat2id_dict[key]) + else: + id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)]) + weight_list.append(1.0) + return id_list, weight_list + # + + + +def mkdir_path(file_path): + if not os.path.exists(file_path): + os.makedirs(file_path) + # + +def statsdata(data_file_path, output_path, criteo_stats): + """data status""" + with open(data_file_path, encoding="utf-8") as file_in: + errorline_list = [] + count = 0 + for line in file_in: + count += 1 + line = line.strip("\n") + items = line.strip("\t") + if len(items) != 40: + errorline_list.append(count) + print("line: {}".format(line)) + continue + if count % 1000000 == 0: + print("Have handle {}w lines.".format(count//10000)) + # if count % 5000000 == 0: + # print("Have handle {}w lines.".format(count//10000)) + # label = items[0] + values = items[1:14] + cats = items[14:] + assert len(values) == 13, "value.size: {}".format(len(values)) + assert len(cats) == 26, "cat.size: {}".format(len(cats)) + criteo_stats.stats_vals(values) + criteo_stats.stats_cats(cats) + criteo_stats.save_dict(output_path) + # + + +def add_write(file_path, wr_str): + with open(file_path, "a", encoding="utf-8") as file_out: + file_out.write(wr_str + "\n") +# + + +def random_split_trans2h5(in_file_path, output_path, criteo_stats, part_rows=2000000, test_size=0.1, seed=2020): + """random split trans2h5""" + test_size = int(TRAIN_LINE_COUNT * test_size) + # train_size = TRAIN_LINE_COUNT - test_size + all_indices = [i for i in range(TRAIN_LINE_COUNT)] + np.random.seed(seed) + np.random.shuffle(all_indices) + print("all_indices.size: {}".format(len(all_indices))) + # lines_count_dict = collections.defaultdict(int) + test_indices_set = set(all_indices[:test_size]) + print("test_indices_set.size: {}".format(len(test_indices_set))) + print("------" * 10 + "\n" * 2) + + train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5") + train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5") + test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5") + test_label_file_name = os.path.join(output_path, "test_input_part_{}.h5") + train_feature_list = [] + train_label_list = [] + test_feature_list = [] + test_label_list = [] + with open(in_file_path, encoding="utf-8") as file_in: + count = 0 + train_part_number = 0 + test_part_number = 0 + for i, line in enumerate(file_in): + count += 1 + if count % 1000000 == 0: + print("Have handle {}w lines.".format(count // 10000)) + line = line.strip("\n") + items = line.split("\t") + if len(items) != 40: + continue + label = float(items[0]) + values = items[1:14] + cats = items[14:] + assert len(values) == 13, "value.size: {}".format(len(values)) + assert len(cats) == 26, "cat.size: {}".format(len(cats)) + ids, wts = criteo_stats.map_cat2id(values, cats) + if i not in test_indices_set: + train_feature_list.append(ids + wts) + train_label_list.append(label) + else: + test_feature_list.append(ids + wts) + test_label_list.append(label) + if 
train_label_list and (len(train_label_list) % part_rows == 0): + pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), + key="fixed") + pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), + key="fixed") + train_feature_list = [] + train_label_list = [] + train_part_number += 1 + if test_label_list and (len(test_label_list) % part_rows == 0): + pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), + key="fixed") + pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), + key="fixed") + test_feature_list = [] + test_label_list = [] + test_part_number += 1 + # + if train_label_list: + pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), + key="fixed") + pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), + key="fixed") + if test_label_list: + pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), + key="fixed") + pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), + key="fixed") +# + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Get and Process datasets") + parser.add_argument("--raw_data_path", default="/opt/npu/data/origin_criteo_data/", help="The path to save dataset") + parser.add_argument("--output_path", default="/opt/npu/data/origin_criteo_data/h5_data/", + help="The path to save dataset") + args, _ = parser.parse_known_args() + base_path = args.raw_data_path + criteo_stat = CriteoStatsDict() + # step 1, stats the vocab and normalize value + datafile_path = base_path + "train_small.txt" + stats_out_path = base_path + "stats_dict/" + mkdir_path(stats_out_path) + statsdata(datafile_path, stats_out_path, criteo_stat) + print("------" * 10) + criteo_stat.load_dict(dict_path=stats_out_path, prefix="") + criteo_stat.get_cat2id(threshold=100) + # step 2, transform data trans2h5; version 2: np.random.shuffle + infile_path = base_path + "train_small.txt" + mkdir_path(args.output_path) + random_split_trans2h5(infile_path, args.output_path, criteo_stat, part_rows=2000000, test_size=0.1, seed=2020) diff --git a/example/wide_and_deep/src/wide_and_deep.py b/example/wide_and_deep/src/wide_and_deep.py new file mode 100644 index 0000000000000000000000000000000000000000..7772431ab330428163b68d7d8f512b185dfcf398 --- /dev/null +++ b/example/wide_and_deep/src/wide_and_deep.py @@ -0,0 +1,311 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""wide and deep model""" +from mindspore import nn +from mindspore import Tensor, Parameter, ParameterTuple +import mindspore.common.dtype as mstype +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.ops import operations as P +# from mindspore.nn import Dropout +from mindspore.nn.optim import Adam, FTRL +# from mindspore.nn.metrics import Metric +from mindspore.common.initializer import Uniform, initializer +# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +import numpy as np + +np_type = np.float32 +ms_type = mstype.float32 + + +def init_method(method, shape, name, max_val=1.0): + ''' + parameter init method + ''' + if method in ['uniform']: + params = Parameter(initializer( + Uniform(max_val), shape, ms_type), name=name) + elif method == "one": + params = Parameter(initializer("ones", shape, ms_type), name=name) + elif method == 'zero': + params = Parameter(initializer("zeros", shape, ms_type), name=name) + elif method == "normal": + params = Parameter(Tensor(np.random.normal( + loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name) + return params + + +def init_var_dict(init_args, in_vars): + ''' + var init function + ''' + var_map = {} + _, _max_val = init_args + for _, iterm in enumerate(in_vars): + key, shape, method = iterm + if key not in var_map.keys(): + if method in ['random', 'uniform']: + var_map[key] = Parameter(initializer( + Uniform(_max_val), shape, ms_type), name=key) + elif method == "one": + var_map[key] = Parameter(initializer( + "ones", shape, ms_type), name=key) + elif method == "zero": + var_map[key] = Parameter(initializer( + "zeros", shape, ms_type), name=key) + elif method == 'normal': + var_map[key] = Parameter(Tensor(np.random.normal( + loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=key) + return var_map + + +class DenseLayer(nn.Cell): + """ + Dense Layer for Deep Layer of WideDeep Model; + Containing: activation, matmul, bias_add; + Args: + """ + + def __init__(self, input_dim, output_dim, weight_bias_init, act_str, + keep_prob=0.7, scale_coef=1.0, convert_dtype=True): + super(DenseLayer, self).__init__() + weight_init, bias_init = weight_bias_init + self.weight = init_method( + weight_init, [input_dim, output_dim], name="weight") + self.bias = init_method(bias_init, [output_dim], name="bias") + self.act_func = self._init_activation(act_str) + self.matmul = P.MatMul(transpose_b=False) + self.bias_add = P.BiasAdd() + self.cast = P.Cast() + #self.dropout = Dropout(keep_prob=keep_prob) + self.mul = P.Mul() + self.realDiv = P.RealDiv() + self.scale_coef = scale_coef + self.convert_dtype = convert_dtype + + def _init_activation(self, act_str): + act_str = act_str.lower() + if act_str == "relu": + act_func = P.ReLU() + elif act_str == "sigmoid": + act_func = P.Sigmoid() + elif act_str == "tanh": + act_func = P.Tanh() + return act_func + + def construct(self, x): + x = self.act_func(x) + # if self.training: + # x = self.dropout(x) + x = self.mul(x, self.scale_coef) + if self.convert_dtype: + x = self.cast(x, mstype.float16) + weight = self.cast(self.weight, mstype.float16) + wx = self.matmul(x, weight) + wx = self.cast(wx, mstype.float32) + else: + wx = self.matmul(x, self.weight) + wx = self.realDiv(wx, self.scale_coef) + output = self.bias_add(wx, self.bias) + return output + + +class WideDeepModel(nn.Cell): + """ + From paper: " Wide & Deep Learning for Recommender Systems" + Args: 
+ config (Class): The default config of Wide&Deep + """ + + def __init__(self, config): + super(WideDeepModel, self).__init__() + self.batch_size = config.batch_size + self.field_size = config.field_size + self.vocab_size = config.vocab_size + self.emb_dim = config.emb_dim + self.deep_layer_dims_list = config.deep_layer_dim + self.deep_layer_act = config.deep_layer_act + self.init_args = config.init_args + self.weight_init, self.bias_init = config.weight_bias_init + self.weight_bias_init = config.weight_bias_init + self.emb_init = config.emb_init + self.drop_out = config.dropout_flag + self.keep_prob = config.keep_prob + self.deep_input_dims = self.field_size * self.emb_dim + self.layer_dims = self.deep_layer_dims_list + [1] + self.all_dim_list = [self.deep_input_dims] + self.layer_dims + + init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init), + ('V_l2', [self.vocab_size, self.emb_dim], self.emb_init), + ('Wide_b', [1], self.emb_init)] + var_map = init_var_dict(self.init_args, init_acts) + self.wide_w = var_map["Wide_w"] + self.wide_b = var_map["Wide_b"] + self.embedding_table = var_map["V_l2"] + self.dense_layer_1 = DenseLayer(self.all_dim_list[0], + self.all_dim_list[1], + self.weight_bias_init, + self.deep_layer_act, convert_dtype=True) + self.dense_layer_2 = DenseLayer(self.all_dim_list[1], + self.all_dim_list[2], + self.weight_bias_init, + self.deep_layer_act, convert_dtype=True) + self.dense_layer_3 = DenseLayer(self.all_dim_list[2], + self.all_dim_list[3], + self.weight_bias_init, + self.deep_layer_act, convert_dtype=True) + self.dense_layer_4 = DenseLayer(self.all_dim_list[3], + self.all_dim_list[4], + self.weight_bias_init, + self.deep_layer_act, convert_dtype=True) + self.dense_layer_5 = DenseLayer(self.all_dim_list[4], + self.all_dim_list[5], + self.weight_bias_init, + self.deep_layer_act, convert_dtype=True) + + self.gather_v2 = P.GatherV2() + self.mul = P.Mul() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.reshape = P.Reshape() + self.square = P.Square() + self.shape = P.Shape() + self.tile = P.Tile() + self.concat = P.Concat(axis=1) + self.cast = P.Cast() + + def construct(self, id_hldr, wt_hldr): + """ + Args: + id_hldr: batch ids; + wt_hldr: batch weights; + """ + mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) + # Wide layer + wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0) + wx = self.mul(wide_id_weight, mask) + wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) + # Deep layer + deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0) + vx = self.mul(deep_id_embs, mask) + deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) + deep_in = self.dense_layer_1(deep_in) + deep_in = self.dense_layer_2(deep_in) + deep_in = self.dense_layer_3(deep_in) + deep_in = self.dense_layer_4(deep_in) + deep_out = self.dense_layer_5(deep_in) + out = wide_out + deep_out + return out, self.embedding_table + + +class NetWithLossClass(nn.Cell): + + """" + Provide WideDeep training loss through network. 
+ Args: + network (Cell): The training network + config (Class): WideDeep config + """ + + def __init__(self, network, config): + super(NetWithLossClass, self).__init__(auto_prefix=False) + self.network = network + self.l2_coef = config.l2_coef + self.loss = P.SigmoidCrossEntropyWithLogits() + self.square = P.Square() + self.reduceMean_false = P.ReduceMean(keep_dims=False) + self.reduceSum_false = P.ReduceSum(keep_dims=False) + + def construct(self, batch_ids, batch_wts, label): + predict, embedding_table = self.network(batch_ids, batch_wts) + log_loss = self.loss(predict, label) + wide_loss = self.reduceMean_false(log_loss) + l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2 + deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v + + return wide_loss, deep_loss + + +class IthOutputCell(nn.Cell): + def __init__(self, network, output_index): + super(IthOutputCell, self).__init__() + self.network = network + self.output_index = output_index + + def construct(self, x1, x2, x3): + predict = self.network(x1, x2, x3)[self.output_index] + return predict + + +class TrainStepWrap(nn.Cell): + """ + Encapsulation class of WideDeep network training. + Append Adam and FTRL optimizers to the training network after that construct + function can be called to create the backward graph. + Args: + network (Cell): the training network. Note that loss function should have been added. + sens (Number): The adjust parameter. Default: 1000.0 + """ + + def __init__(self, network, sens=1000.0): + super(TrainStepWrap, self).__init__() + self.network = network + self.network.set_train() + self.trainable_params = network.trainable_params() + weights_w = [] + weights_d = [] + for params in self.trainable_params: + if 'wide' in params.name: + weights_w.append(params) + else: + weights_d.append(params) + self.weights_w = ParameterTuple(weights_w) + self.weights_d = ParameterTuple(weights_d) + self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0) + self.optimizer_d = Adam( + self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) + self.hyper_map = C.HyperMap() + self.grad_w = C.GradOperation('grad_w', get_by_list=True, + sens_param=True) + self.grad_d = C.GradOperation('grad_d', get_by_list=True, + sens_param=True) + self.sens = sens + self.loss_net_w = IthOutputCell(network, output_index=0) + self.loss_net_d = IthOutputCell(network, output_index=1) + + def construct(self, batch_ids, batch_wts, label): + weights_w = self.weights_w + weights_d = self.weights_d + loss_w, loss_d = self.network(batch_ids, batch_wts, label) + sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) + sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) + grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts, + label, sens_w) + grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts, + label, sens_d) + return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, + self.optimizer_d(grads_d)) + + +class PredictWithSigmoid(nn.Cell): + def __init__(self, network): + super(PredictWithSigmoid, self).__init__() + self.network = network + self.sigmoid = P.Sigmoid() + + def construct(self, batch_ids, batch_wts, labels): + logits, _, _, = self.network(batch_ids, batch_wts) + pred_probs = self.sigmoid(logits) + return logits, pred_probs, labels diff --git a/example/wide_and_deep/test.py b/example/wide_and_deep/test.py new file mode 100644 index 
0000000000000000000000000000000000000000..54969e7c945e05ec3f3040bfaaedca4cfbb2e2bf --- /dev/null +++ b/example/wide_and_deep/test.py @@ -0,0 +1,94 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" test_training """ + +import os + +from mindspore import Model, context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", + save_graphs=True) + + +def get_WideDeep_net(config): + WideDeep_net = WideDeepModel(config) + + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + Wide and deep model builder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def test_eval(config): + """ + test evaluate + """ + data_path = config.data_path + batch_size = config.batch_size + ds_eval = create_dataset(data_path, train_mode=False, epochs=2, + batch_size=batch_size) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + train_net, eval_net = net_builder.get_net(config) + + param_dict = load_checkpoint(config.ckpt_path) + load_param_into_net(eval_net, param_dict) + + auc_metric = AUCMetric() + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + model.eval(ds_eval, callbacks=eval_callback) + + +if __name__ == "__main__": + widedeep_config = WideDeepConfig() + widedeep_config.argparse_init() + + test_eval(widedeep_config.widedeep) diff --git a/example/wide_and_deep/train.py b/example/wide_and_deep/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b3996e01cb5dbc67d392d58ac74758c0c8c0da8a --- /dev/null +++ b/example/wide_and_deep/train.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" test_training """ +import os +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack +from src.datasets import create_dataset +from src.config import WideDeepConfig + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + + +def get_WideDeep_net(configure): + WideDeep_net = WideDeepModel(configure) + + loss_net = NetWithLossClass(WideDeep_net, configure) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + Build the model. + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, configure): + return get_WideDeep_net(configure) + + +def test_train(configure): + """ + test_train + """ + data_path = configure.data_path + batch_size = configure.batch_size + epochs = configure.epochs + ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + + net_builder = ModelBuilder() + train_net, _ = net_builder.get_net(configure) + train_net.set_train() + + model = Model(train_net) + callback = LossCallBack(config=configure) + ckptconfig = CheckpointConfig(save_checkpoint_steps=1, + keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) + model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb]) + + +if __name__ == "__main__": + config = WideDeepConfig() + config.argparse_init() + + test_train(config) diff --git a/example/wide_and_deep/train_and_test.py b/example/wide_and_deep/train_and_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9832b84bc835860cacf5edf25cc754cba151ac --- /dev/null +++ b/example/wide_and_deep/train_and_test.py @@ -0,0 +1,97 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" test_training """ +import os + +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +context.set_context(mode=context.GRAPH_MODE, device_target="Davinci") + + +def get_WideDeep_net(config): + WideDeep_net = WideDeepModel(config) + + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def test_train_eval(config): + """ + test_train_eval + """ + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) + ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + callback = LossCallBack(config=config) + ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) + + out = model.eval(ds_eval) + print("=====" * 5 + "model.eval() initialized: {}".format(out)) + model.train(epochs, ds_train, callbacks=[eval_callback, callback, ckpoint_cb]) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + + test_train_eval(wide_deep_config) diff --git a/example/wide_and_deep/train_and_test_multinpu.py b/example/wide_and_deep/train_and_test_multinpu.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4b99f5462b25f22c99bba60b3a5c72c9c55950 --- /dev/null +++ b/example/wide_and_deep/train_and_test_multinpu.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Distributed (multi-NPU) Wide&Deep training and evaluation."""
+
+
+import os
+import sys
+import numpy as np
+from mindspore import Model, context
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
+from mindspore.train import ParallelMode
+from mindspore.communication.management import get_rank, get_group_size, init
+
+from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
+from src.callbacks import LossCallBack, EvalCallBack
+from src.datasets import create_dataset
+from src.metrics import AUCMetric
+from src.config import WideDeepConfig
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", save_graphs=True)
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+init()
+
+
+def get_WideDeep_net(config):
+    WideDeep_net = WideDeepModel(config)
+    loss_net = NetWithLossClass(WideDeep_net, config)
+    train_net = TrainStepWrap(loss_net)
+    eval_net = PredictWithSigmoid(WideDeep_net)
+    return train_net, eval_net
+
+
+class ModelBuilder():
+    """
+    ModelBuilder
+    """
+    def __init__(self):
+        pass
+
+    def get_hook(self):
+        pass
+
+    def get_train_hook(self):
+        hooks = []
+        callback = LossCallBack()
+        hooks.append(callback)
+        if int(os.getenv('DEVICE_ID')) == 0:
+            pass
+        return hooks
+
+    def get_net(self, config):
+        return get_WideDeep_net(config)
+
+
+def test_train_eval():
+    """
+    Train and evaluate Wide&Deep with data-parallel (multi-NPU) execution.
+    """
+    np.random.seed(1000)
+    config = WideDeepConfig()
+    config.argparse_init()
+    data_path = config.data_path
+    batch_size = config.batch_size
+    epochs = config.epochs
+    print("epochs is {}".format(epochs))
+    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+                             batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
+    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
+
+    net_builder = ModelBuilder()
+
+    train_net, eval_net = net_builder.get_net(config)
+    train_net.set_train()
+    auc_metric = AUCMetric()
+
+    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
+
+    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
+
+    callback = LossCallBack(config=config)
+    ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
+    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
+                                 directory=config.ckpt_path, config=ckptconfig)
+    model.train(epochs, ds_train,
+                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
+
+
+if __name__ == "__main__":
+    test_train_eval()
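
As a quick standalone sanity check of the data pipeline wired into the scripts above, the sketch below (not part of the patch itself) loads one evaluation batch through `src/datasets.py`. It assumes the preprocessed Criteo TFRecord files already sit in the default `data_path` from `src/config.py` (`./test_raw_data/`); adjust the path to wherever the download and preprocessing step placed its output, and run it from the `example/wide_and_deep/` directory.
```
from src.config import WideDeepConfig
from src.datasets import create_dataset

if __name__ == "__main__":
    cfg = WideDeepConfig()
    # train_mode=False selects the test split; the default pipeline expects
    # TFRecord files whose names contain "test" and "tfrecord".
    ds = create_dataset(cfg.data_path, train_mode=False, epochs=1,
                        batch_size=cfg.eval_batch_size)
    print("number of batches:", ds.get_dataset_size())
    for ids, weights, labels in ds.create_tuple_iterator():
        # Expected shapes: ids/weights -> (eval_batch_size, 39), labels -> (eval_batch_size, 1)
        print(ids.shape, weights.shape, labels.shape)
        break
```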