From b8e1786652f1c9f9755ef1d88982c649de47e663 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Thu, 9 Jul 2020 11:51:39 +0800 Subject: [PATCH] add rank model BST (#134) * add rank model BST * update readme --- README.md | 1 + README_CN.md | 1 + models/rank/BST/__init__.py | 13 + models/rank/BST/config.yaml | 84 +++++ models/rank/BST/data/build_dataset.py | 116 ++++++ models/rank/BST/data/convert_pd.py | 41 +++ models/rank/BST/data/data_process.sh | 15 + models/rank/BST/data/remap_id.py | 62 ++++ .../BST/data/train_data/paddle_train.100.txt | 100 +++++ models/rank/BST/model.py | 347 ++++++++++++++++++ models/rank/readme.md | 1 + 11 files changed, 781 insertions(+) create mode 100755 models/rank/BST/__init__.py create mode 100755 models/rank/BST/config.yaml create mode 100755 models/rank/BST/data/build_dataset.py create mode 100755 models/rank/BST/data/convert_pd.py create mode 100755 models/rank/BST/data/data_process.sh create mode 100755 models/rank/BST/data/remap_id.py create mode 100755 models/rank/BST/data/train_data/paddle_train.100.txt create mode 100755 models/rank/BST/model.py diff --git a/README.md b/README.md index b5939cc1..86990286 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ | Rank | [xDeepFM](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023) | | Rank | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823) | | Rank | [DIEN](models/rank/dien/model.py) | ✓ | x | ✓ | x | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423) | + | Rank | [BST](models/rank/BST/model.py) | ✓ | x | ✓ | x | [DLP-KDD 2019][Behavior Sequence Transformer for E-commerce Recommendation in 
Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf) | | Rank | [AutoInt](models/rank/AutoInt/model.py) | ✓ | x | ✓ | x | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf) | | Rank | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/2988450.2988454) | | Rank | [FGCNN](models/rank/fgcnn/model.py) | ✓ | ✓ | ✓ | ✓ | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf) | diff --git a/README_CN.md b/README_CN.md index 3b5fe858..3f3f8f0e 100644 --- a/README_CN.md +++ b/README_CN.md @@ -61,6 +61,7 @@ | 排序 | [xDeepFM](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023) | | 排序 | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823) | | 排序 | [DIEN](models/rank/dien/model.py) | ✓ | x | ✓ | x | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423) | + | 排序 | [BST](models/rank/BST/model.py) | ✓ | x | ✓ | x | [DLP-KDD 2019][Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf) | | 排序 | [AutoInt](models/rank/AutoInt/model.py) | ✓ | x | ✓ | x | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf) | | 排序 | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/2988450.2988454) | | 排序 | [FGCNN](models/rank/fgcnn/model.py) | ✓ | ✓ | ✓ | ✓ |
[WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf) | diff --git a/models/rank/BST/__init__.py b/models/rank/BST/__init__.py new file mode 100755 index 00000000..abf198b9 --- /dev/null +++ b/models/rank/BST/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/rank/BST/config.yaml b/models/rank/BST/config.yaml new file mode 100755 index 00000000..73e39f19 --- /dev/null +++ b/models/rank/BST/config.yaml @@ -0,0 +1,84 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# global settings
+debug: false
+workspace: "paddlerec.models.rank.BST"
+
+dataset:
+- name: sample_1
+  type: DataLoader
+  batch_size: 5
+  data_path: "{workspace}/data/train_data"
+  sparse_slots: "label history cate position target target_cate target_position"  # same slot names that data/build_dataset.py writes as "<slot>:<value>"
+- name: infer_sample
+  type: DataLoader
+  batch_size: 5
+  data_path: "{workspace}/data/train_data"  # NOTE(review): same directory as sample_1, i.e. the training data
+  sparse_slots: "label history cate position target target_cate target_position"
+
+hyper_parameters:
+  optimizer:
+    class: SGD
+    learning_rate: 0.0001
+  use_DataLoader: True
+  item_emb_size: 96
+  cat_emb_size: 96
+  position_emb_size: 96
+  is_sparse: False
+  item_count: 63001
+  cat_count: 801
+  position_count: 5001  # data/build_dataset.py caps position values at 5000
+  n_encoder_layers: 1
+  d_model: 288  # NOTE(review): 288 = 96 * 3 = n_head * d_key; presumably model.py depends on this, confirm before changing
+  d_key: 48
+  d_value: 48
+  n_head: 6
+  dropout_rate: 0
+  postprocess_cmd: "da"
+  prepostprocess_dropout: 0
+  d_inner_hid: 512
+  relu_dropout: 0.0
+  act: "relu"
+  fc_sizes: [1024, 512, 256]
+
+
+mode: train_runner  # matches the name of the first entry under "runner"
+
+runner:
+  - name: train_runner
+    class: train
+    epochs: 1
+    device: cpu
+    init_model_path: ""
+    save_checkpoint_interval: 1
+    save_inference_interval: 1
+    save_checkpoint_path: "increment_BST"
+    save_inference_path: "inference_BST"
+    print_interval: 1
+  - name: infer_runner
+    class: infer
+    device: cpu
+    init_model_path: "increment_BST/0"  # sub-directory of train_runner's save_checkpoint_path
+    print_interval: 1
+
+phase:
+- name: phase1
+  model: "{workspace}/model.py"
+  dataset_name: sample_1  # refers to the dataset entry named sample_1 above
+  thread_num: 1
+#- name: infer_phase
+#  model: "{workspace}/model.py"
+#  dataset_name: infer_sample
+#  thread_num: 1
diff --git a/models/rank/BST/data/build_dataset.py b/models/rank/BST/data/build_dataset.py
new file mode 100755
index 00000000..137d8652
--- /dev/null
+++ b/models/rank/BST/data/build_dataset.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import random
+import pickle
+
+random.seed(1234)  # fixed seed: negative sampling and shuffling are repeatable
+
+print("read and process data")
+
+with open('./raw_data/remap.pkl', 'rb') as f:  # objects are read back in dump order
+    reviews_df = pickle.load(f)
+    cate_list = pickle.load(f)  # indexed by item id below (cate_list[x])
+    user_count, item_count, cate_count, example_count = pickle.load(f)
+
+train_set = []
+test_set = []
+for reviewerID, hist in reviews_df.groupby('reviewerID'):
+    pos_list = hist['asin'].tolist()
+    time_list = hist['unixReviewTime'].tolist()
+
+    def gen_neg():  # draw a random item id that is not in this user's pos_list
+        neg = pos_list[0]
+        while neg in pos_list:
+            neg = random.randint(0, item_count - 1)
+        return neg
+
+    neg_list = [gen_neg() for _ in pos_list]
+
+    for i in range(1, len(pos_list)):
+        hist = pos_list[:i]
+        # position = gap to the target review in units of 3600 * 24, capped at 5000
+        time_seq = [
+            min(int((time_list[i] - time_list[j]) / (3600 * 24)), 5000)
+            for j in range(i)
+        ]
+        if i != len(pos_list) - 1:
+            train_set.append((reviewerID, hist, pos_list[i], 1, time_seq))
+            train_set.append((reviewerID, hist, neg_list[i], 0, time_seq))
+        else:  # the last review of each user becomes the test example
+            label = (pos_list[i], neg_list[i])
+            test_set.append((reviewerID, hist, label, time_seq))
+
+random.shuffle(train_set)
+random.shuffle(test_set)
+
+assert len(test_set) == user_count  # holds only if every user has >= 2 reviews
+
+
+def print_to_file(data, fout, slot):  # write every value as "<slot>:<value> "
+    if not isinstance(data, list):
+        data = [data]
+    for value in data:
+        fout.write(slot + ":" + str(value))
+        fout.write(' ')
+
+
+print("make train data")
+with open("paddle_train.txt", "w") as fout:
+    for line in train_set:
+        history = line[1]
+        target = line[2]
+        label = line[3]
+        position = line[4]
+        cate = [cate_list[x] for x in history]
+        print_to_file(history, fout, "history")
+        print_to_file(cate, fout, "cate")
+        print_to_file(position, fout, "position")
+        print_to_file(target, fout, "target")
+        print_to_file(cate_list[target], fout, "target_cate")
+        print_to_file(0, fout, "target_position")
+        print_to_file(label, fout, "label")
+        fout.write("\n")
+
+print("make test data")
+with open("paddle_test.txt", "w") as fout:
+    for line in test_set:
+        history = line[1]
+        target = line[2]  # (positive item id, negative item id)
+        position = line[3]
+        cate = [cate_list[x] for x in history]
+
+        print_to_file(history, fout, "history")
+        print_to_file(cate, fout, "cate")
+        print_to_file(position, fout, "position")
+        print_to_file(target[0], fout, "target")
+        print_to_file(cate_list[target[0]], fout, "target_cate")
+        print_to_file(0, fout, "target_position")
+        fout.write("label:1\n")
+
+        print_to_file(history, fout, "history")
+        print_to_file(cate, fout, "cate")
+        print_to_file(position, fout, "position")
+        print_to_file(target[1], fout, "target")  # negative item id, consistent with target_cate
+        print_to_file(cate_list[target[1]], fout, "target_cate")
+        print_to_file(0, fout, "target_position")
+        fout.write("label:0\n")
+
+print("make config data")
+with open('config.txt', 'w') as f:
+    f.write(str(user_count) + "\n")
+    f.write(str(item_count) + "\n")
+    f.write(str(cate_count) + "\n")
+    f.write(str(50000) + "\n")  # NOTE(review): positions are capped at 5000 above; confirm 50000 is intended
diff --git a/models/rank/BST/data/convert_pd.py b/models/rank/BST/data/convert_pd.py
new file mode 100755
index 00000000..a66290e1
--- /dev/null
+++ b/models/rank/BST/data/convert_pd.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import pickle
+import pandas as pd
+
+
+def to_df(file_path):
+    """Load a file holding one dict literal per line into a DataFrame."""
+    with open(file_path, 'r') as fin:
+        # NOTE(security): eval() executes every input line as Python code; only
+        # run this on files you trust (ast.literal_eval is the safe option).
+        rows = {}
+        for i, line in enumerate(fin):
+            rows[i] = eval(line)
+        return pd.DataFrame.from_dict(rows, orient='index')
+
+
+print("start to analyse reviews_Electronics_5.json")
+reviews_df = to_df('./raw_data/reviews_Electronics_5.json')
+with open('./raw_data/reviews.pkl', 'wb') as f:
+    pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
+
+print("start to analyse meta_Electronics.json")
+meta_df = to_df('./raw_data/meta_Electronics.json')
+meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]  # keep only items that have reviews
+meta_df = meta_df.reset_index(drop=True)
+with open('./raw_data/meta.pkl', 'wb') as f:
+    pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)
diff --git a/models/rank/BST/data/data_process.sh b/models/rank/BST/data/data_process.sh
new file mode 100755
index 00000000..7bcfc55f
--- /dev/null
+++ b/models/rank/BST/data/data_process.sh
@@ -0,0 +1,15 @@
+#! /bin/bash
+
+set -e
+echo "begin download data"
+mkdir -p raw_data  # -p: do not abort (set -e) when the directory already exists
+cd raw_data
+wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
+gzip -d reviews_Electronics_5.json.gz
+wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz
+gzip -d meta_Electronics.json.gz
+echo "download data successfully"
+
+cd ..
+python convert_pd.py
+python remap_id.py
diff --git a/models/rank/BST/data/remap_id.py b/models/rank/BST/data/remap_id.py
new file mode 100755
index 00000000..ee6983d7
--- /dev/null
+++ b/models/rank/BST/data/remap_id.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import random
+import pickle
+import numpy as np
+
+random.seed(1234)
+
+with open('./raw_data/reviews.pkl', 'rb') as f:  # pickle written by convert_pd.py
+    reviews_df = pickle.load(f)
+    reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
+with open('./raw_data/meta.pkl', 'rb') as f:  # pickle written by convert_pd.py
+    meta_df = pickle.load(f)
+    meta_df = meta_df[['asin', 'categories']]
+    meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])  # keep the last element of the last entry
+
+
+def build_map(df, col_name):  # replaces df[col_name] in place with dense integer ids
+    key = sorted(df[col_name].unique().tolist())
+    m = dict(zip(key, range(len(key))))  # value -> its rank in sorted order
+    df[col_name] = df[col_name].map(lambda x: m[x])
+    return m, key  # (value -> id dict, sorted original values)
+
+
+asin_map, asin_key = build_map(meta_df, 'asin')
+cate_map, cate_key = build_map(meta_df, 'categories')
+revi_map, revi_key = build_map(reviews_df, 'reviewerID')
+
+user_count, item_count, cate_count, example_count =\
+    len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]
+print('user_count: %d\titem_count: %d\tcate_count: %d\texample_count: %d' %
+      (user_count, item_count, cate_count, example_count))
+
+meta_df = meta_df.sort_values('asin')  # 'asin' already holds the integer ids from build_map
+meta_df = meta_df.reset_index(drop=True)
+reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])  # reuse the item ids assigned from meta_df
+reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])  # order each user's reviews by time
+reviews_df = reviews_df.reset_index(drop=True)
+reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
+
+cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]  # category id of row i of the asin-sorted meta_df
+cate_list = np.array(cate_list, dtype=np.int32)
+
+with open('./raw_data/remap.pkl', 'wb') as f:
+    pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)  # reviews with remapped user and item ids
+    pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)  # category id per item id
+    pickle.dump((user_count, item_count, cate_count, example_count), f,
+                pickle.HIGHEST_PROTOCOL)
+    pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)
diff --git a/models/rank/BST/data/train_data/paddle_train.100.txt b/models/rank/BST/data/train_data/paddle_train.100.txt
new file mode 100755
index 00000000..a65d9341
--- /dev/null
+++ b/models/rank/BST/data/train_data/paddle_train.100.txt
@@ -0,0 +1,100 @@
+history:3737 history:19450 cate:288 cate:196 position:518 position:158 target:18486 target_cate:674 label:1
+history:3647 history:4342 history:6855 history:3805 cate:281 cate:463 cate:558 cate:674 position:242 position:216 position:17 position:5 target:4206 target_cate:463 label:1
+history:1805 history:4309 cate:87 cate:87 position:61 position:0 target:21354 target_cate:556 label:1
+history:18209 history:20753 cate:649 cate:241 position:0 position:0 target:51924 target_cate:610 label:0
+history:13150 cate:351 position:505 target:41455 target_cate:792 label:1
+history:35120 history:40418 cate:157 cate:714 position:0 position:0 target:52035 target_cate:724 label:0
+history:13515 history:20363 history:25356 history:26891 history:24200 history:11694 history:33378 history:34483 history:35370 history:27311 history:40689 history:33319 history:28819 cate:558 cate:123 cate:61 cate:110 cate:738 cate:692 cate:110 cate:629 cate:714 cate:463 cate:281 cate:142 cate:382 position:1612 position:991 position:815
position:668 position:639 position:508 position:456 position:431 position:409 position:222 position:221 position:74 position:34 target:45554 target_cate:558 label:1 +history:19254 history:9021 history:28156 history:19193 history:24602 history:31171 cate:189 cate:462 cate:140 cate:474 cate:157 cate:614 position:375 position:144 position:141 position:0 position:0 position:0 target:48895 target_cate:350 label:1 +history:4716 cate:194 position:2457 target:32497 target_cate:484 label:1 +history:43799 history:47108 cate:368 cate:140 position:181 position:105 target:3503 target_cate:25 label:0 +history:20554 history:41800 history:1582 history:1951 cate:339 cate:776 cate:694 cate:703 position:35 position:35 position:0 position:0 target:4320 target_cate:234 label:0 +history:39713 history:44272 history:45136 history:11687 cate:339 cate:339 cate:339 cate:140 position:40 position:40 position:40 position:0 target:885 target_cate:168 label:0 +history:14398 history:33997 cate:756 cate:347 position:73 position:73 target:20438 target_cate:703 label:1 +history:29341 history:25727 cate:142 cate:616 position:839 position:0 target:4170 target_cate:512 label:0 +history:12197 history:10212 cate:558 cate:694 position:1253 position:677 target:31559 target_cate:24 label:0 +history:11551 cate:351 position:47 target:53485 target_cate:436 label:1 +history:4553 cate:196 position:88 target:7331 target_cate:158 label:1 +history:15190 history:19994 history:33946 history:30716 history:31879 history:45178 history:51598 history:46814 cate:249 cate:498 cate:612 cate:142 cate:746 cate:746 cate:558 cate:174 position:1912 position:1275 position:1170 position:1122 position:773 position:773 position:329 position:291 target:24353 target_cate:251 label:0 +history:4931 history:2200 history:8338 history:23530 cate:785 cate:792 cate:277 cate:523 position:1360 position:975 position:975 position:586 target:3525 target_cate:251 label:0 +history:8881 history:13274 history:12683 history:14696 history:27693 
history:1395 history:44373 history:59704 history:27762 history:54268 history:30326 history:11811 history:45371 history:51598 history:55859 history:56039 history:57678 history:47250 history:2073 history:38932 cate:479 cate:558 cate:190 cate:708 cate:335 cate:684 cate:339 cate:725 cate:446 cate:446 cate:44 cate:575 cate:280 cate:558 cate:262 cate:197 cate:368 cate:111 cate:749 cate:188 position:2065 position:2065 position:1292 position:1108 position:647 position:343 position:343 position:343 position:257 position:257 position:143 position:76 position:76 position:76 position:76 position:76 position:76 position:58 position:6 position:6 target:12361 target_cate:616 label:1 +history:16297 history:16797 history:18629 history:20922 history:16727 history:33946 history:51165 history:36796 cate:281 cate:436 cate:462 cate:339 cate:611 cate:612 cate:288 cate:64 position:1324 position:1324 position:1324 position:1118 position:183 position:133 position:6 position:4 target:34724 target_cate:288 label:1 +history:22237 cate:188 position:339 target:40786 target_cate:637 label:0 +history:5396 history:39993 history:42681 history:49832 history:11208 history:34954 history:36523 history:45523 history:51618 cate:351 cate:339 cate:687 cate:281 cate:708 cate:142 cate:629 cate:656 cate:142 position:1117 position:290 position:276 position:191 position:144 position:144 position:120 position:66 position:66 target:38201 target_cate:571 label:0 +history:8881 history:9029 history:17043 history:16620 history:15021 history:32706 cate:479 cate:110 cate:110 cate:749 cate:598 cate:251 position:1218 position:1218 position:790 position:695 position:264 position:1 target:34941 target_cate:657 label:0 +history:53255 cate:444 position:232 target:37953 target_cate:724 label:1 +history:1010 history:4172 history:8613 history:11562 history:11709 history:13118 history:2027 history:15446 cate:674 cate:606 cate:708 cate:436 cate:179 cate:179 cate:692 cate:436 position:324 position:323 position:323 position:323 
position:323 position:308 position:307 position:307 target:36998 target_cate:703 label:0 +history:22357 history:24305 history:15222 history:19254 history:22914 cate:189 cate:504 cate:113 cate:189 cate:714 position:321 position:321 position:232 position:232 position:232 target:18201 target_cate:398 label:1 +history:1905 cate:694 position:0 target:23877 target_cate:347 label:1 +history:8444 history:17868 cate:765 cate:712 position:454 position:0 target:50732 target_cate:44 label:0 +history:42301 history:26186 history:38086 cate:142 cate:450 cate:744 position:164 position:0 position:0 target:61547 target_cate:714 label:0 +history:18156 history:35717 history:32070 history:45650 history:47208 history:20975 history:36409 history:44856 history:48072 history:15860 history:47043 history:53289 history:53314 history:33470 history:47926 cate:157 cate:281 cate:650 cate:142 cate:749 cate:291 cate:707 cate:714 cate:157 cate:205 cate:388 cate:474 cate:708 cate:498 cate:495 position:546 position:506 position:296 position:296 position:263 position:253 position:253 position:221 position:121 position:26 position:26 position:26 position:26 position:0 position:0 target:48170 target_cate:746 label:1 +history:56219 cate:108 position:0 target:1988 target_cate:389 label:0 +history:22907 cate:83 position:353 target:752 target_cate:175 label:0 +history:22009 history:32410 history:42987 history:48720 history:683 history:1289 history:2731 history:4736 history:6306 history:8442 history:8946 history:9928 history:11536 history:14947 history:15793 history:16694 history:21736 history:25156 history:25797 history:25874 history:26573 history:30318 history:33946 history:35420 history:1492 history:5236 history:5555 history:6625 history:8867 history:9638 history:11443 history:20225 history:25965 history:27273 history:29001 history:35302 history:42336 history:43347 history:36907 history:2012 cate:317 cate:462 cate:291 cate:142 cate:694 cate:10 cate:574 cate:278 cate:708 cate:281 cate:131 cate:142 cate:367 
cate:281 cate:258 cate:345 cate:616 cate:708 cate:111 cate:115 cate:339 cate:113 cate:612 cate:24 cate:368 cate:616 cate:39 cate:197 cate:44 cate:214 cate:558 cate:108 cate:616 cate:558 cate:210 cate:210 cate:142 cate:142 cate:262 cate:351 position:390 position:390 position:390 position:390 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:366 position:339 position:333 position:167 target:25540 target_cate:701 label:0 +history:20434 cate:196 position:610 target:18056 target_cate:189 label:0 +history:628 history:5461 cate:194 cate:234 position:294 position:74 target:43677 target_cate:351 label:0 +history:16953 history:15149 history:45143 history:23587 history:5094 history:25105 history:51913 history:54645 cate:484 cate:281 cate:449 cate:792 cate:524 cate:395 cate:388 cate:731 position:1134 position:668 position:626 position:409 position:285 position:285 position:285 position:42 target:57655 target_cate:75 label:1 +history:13584 history:7509 cate:234 cate:744 position:1187 position:231 target:33062 target_cate:749 label:1 +history:170 history:208 history:77 history:109 history:738 history:742 history:1118 history:15349 history:255 history:12067 history:21643 history:55453 cate:330 cate:559 cate:744 cate:115 cate:558 cate:674 cate:111 cate:351 cate:694 cate:694 cate:746 cate:111 position:4920 position:4726 position:4585 position:4585 position:4585 position:4584 position:4108 position:1418 position:1326 position:274 position:89 position:88 target:9821 target_cate:694 label:1 +history:4970 history:16672 cate:540 cate:746 position:416 position:120 target:25685 target_cate:666 label:1 +history:17240 
history:60546 cate:708 cate:629 position:165 position:41 target:42110 target_cate:142 label:1 +history:31503 history:31226 history:50628 history:22444 cate:142 cate:156 cate:142 cate:203 position:187 position:162 position:109 position:0 target:47812 target_cate:749 label:0 +history:2443 history:1763 history:3403 history:4225 history:8951 cate:25 cate:707 cate:351 cate:177 cate:351 position:1397 position:1113 position:973 position:637 position:254 target:7954 target_cate:351 label:1 +history:3748 cate:351 position:1086 target:9171 target_cate:657 label:1 +history:1755 history:26204 history:42716 history:32991 cate:446 cate:188 cate:497 cate:746 position:440 position:184 position:91 position:52 target:23910 target_cate:395 label:1 +history:20637 history:27122 cate:558 cate:44 position:1122 position:0 target:19669 target_cate:301 label:0 +history:406 history:872 history:306 history:218 history:883 history:1372 history:1705 history:1709 history:7774 history:2376 history:2879 history:2881 history:13329 history:4992 history:13594 history:11106 history:7131 history:8631 history:1736 history:17585 history:2568 history:16896 history:21971 history:10296 history:22361 history:24108 history:23300 history:11793 history:25351 history:2648 history:24593 history:12692 history:23883 history:25345 history:27129 history:26321 history:21627 history:20738 history:17784 history:28785 history:29281 history:28366 history:24723 history:24319 history:12083 history:29882 history:29974 history:30443 history:30428 history:17072 history:9783 history:16700 history:29421 history:32253 history:28830 history:31299 history:28792 history:33931 history:24973 history:33112 history:21717 history:28339 history:23978 history:18649 history:1841 history:17635 history:19696 history:37448 history:20862 history:30492 history:35736 history:37450 history:2633 history:8675 history:17412 history:25960 history:28389 history:31032 history:37157 history:14555 history:4996 history:33388 history:33393 history:36237 
history:38946 history:22793 history:24337 history:34963 history:38819 history:41165 history:39551 history:43019 history:15570 history:25129 history:34593 history:38385 history:42915 history:41407 history:29907 history:31289 history:44229 history:24267 history:34975 history:39462 history:33274 history:43251 history:38302 history:35502 history:44056 history:44675 history:45233 history:47690 history:33472 history:50149 history:29409 history:47183 history:49188 history:48192 history:50628 history:24103 history:28313 history:28358 history:38882 history:44330 history:44346 history:2019 history:2484 history:2675 history:26396 history:48143 history:46039 history:47722 history:48559 history:41719 history:41720 history:43920 history:41983 history:51235 history:34964 history:27287 history:51915 history:33586 history:43630 history:47258 history:52137 history:40954 history:35120 history:29572 history:42405 history:53559 history:44900 history:45761 cate:241 cate:558 cate:395 cate:368 cate:498 cate:110 cate:463 cate:611 cate:558 cate:106 cate:10 cate:112 cate:251 cate:241 cate:48 cate:112 cate:601 cate:674 cate:241 cate:347 cate:733 cate:502 cate:194 cate:119 cate:179 cate:179 cate:578 cate:692 cate:281 cate:115 cate:523 cate:113 cate:281 cate:35 cate:765 cate:196 cate:339 cate:115 cate:90 cate:164 cate:790 cate:708 cate:142 cate:115 cate:342 cate:351 cate:391 cate:281 cate:48 cate:119 cate:74 cate:505 cate:606 cate:68 cate:239 cate:687 cate:687 cate:281 cate:110 cate:281 cate:449 cate:351 cate:38 cate:351 cate:164 cate:176 cate:449 cate:115 cate:70 cate:25 cate:687 cate:115 cate:39 cate:756 cate:35 cate:175 cate:704 cate:119 cate:38 cate:53 cate:115 cate:38 cate:38 cate:142 cate:262 cate:188 cate:614 cate:277 cate:388 cate:615 cate:49 cate:738 cate:106 cate:733 cate:486 cate:666 cate:571 cate:385 cate:708 cate:119 cate:331 cate:463 cate:578 cate:288 cate:142 cate:106 cate:611 cate:611 cate:39 cate:523 cate:388 cate:142 cate:726 cate:702 cate:498 cate:61 cate:142 cate:714 
cate:142 cate:654 cate:277 cate:733 cate:603 cate:498 cate:299 cate:97 cate:726 cate:115 cate:637 cate:703 cate:558 cate:74 cate:629 cate:142 cate:142 cate:347 cate:629 cate:746 cate:277 cate:8 cate:49 cate:389 cate:629 cate:408 cate:733 cate:345 cate:157 cate:704 cate:115 cate:398 cate:611 cate:239 position:3925 position:3925 position:3909 position:3897 position:3879 position:3644 position:3611 position:3524 position:2264 position:1913 position:1730 position:1730 position:1684 position:1657 position:1643 position:1626 position:1566 position:1430 position:1375 position:1351 position:1298 position:1298 position:1221 position:1217 position:1177 position:1149 position:1142 position:1141 position:1083 position:1079 position:1067 position:1045 position:1031 position:997 position:994 position:993 position:987 position:968 position:946 position:945 position:905 position:904 position:903 position:897 position:856 position:855 position:813 position:813 position:801 position:799 position:798 position:791 position:791 position:767 position:765 position:761 position:756 position:751 position:747 position:730 position:672 position:659 position:652 position:620 position:619 position:619 position:597 position:596 position:582 position:555 position:555 position:532 position:484 position:484 position:484 position:483 position:483 position:468 position:468 position:467 position:454 position:454 position:454 position:441 position:427 position:409 position:409 position:409 position:409 position:409 position:387 position:387 position:381 position:381 position:381 position:360 position:360 position:357 position:355 position:337 position:332 position:317 position:294 position:271 position:213 position:206 position:204 position:202 position:202 position:182 position:182 position:173 position:154 position:142 position:135 position:114 position:110 position:107 position:107 position:95 position:95 position:95 position:95 position:94 position:92 position:90 position:90 position:90 
position:90 position:90 position:86 position:86 position:86 position:84 position:84 position:84 position:83 position:83 position:80 position:65 position:51 position:41 position:23 position:23 position:23 position:22 position:18 position:7 position:3 position:3 position:0 position:0 target:49174 target_cate:368 label:0 +history:29206 history:60955 cate:351 cate:684 position:32 position:32 target:61590 target_cate:76 label:1 +history:8427 history:9692 history:4411 history:3266 history:18234 history:22774 cate:746 cate:281 cate:396 cate:651 cate:446 cate:44 position:1204 position:1129 position:808 position:622 position:134 position:134 target:23393 target_cate:351 label:0 +history:13051 history:15844 history:9347 history:21973 history:18365 history:24220 history:28429 history:4799 history:27488 history:21623 history:13870 history:29346 history:27208 history:31075 history:31635 history:28390 history:30777 history:29334 history:33438 history:16469 history:29423 history:29237 history:25527 history:34808 history:37656 history:21324 history:38263 history:6699 history:33167 history:9295 history:40828 history:18894 cate:339 cate:342 cate:657 cate:194 cate:20 cate:466 cate:179 cate:225 cate:436 cate:364 cate:707 cate:115 cate:36 cate:523 cate:351 cate:674 cate:694 cate:391 cate:674 cate:500 cate:342 cate:216 cate:707 cate:345 cate:616 cate:495 cate:436 cate:363 cate:395 cate:189 cate:203 cate:766 position:1400 position:1032 position:849 position:827 position:804 position:469 position:467 position:463 position:460 position:456 position:455 position:451 position:371 position:371 position:371 position:315 position:315 position:314 position:311 position:287 position:282 position:281 position:239 position:105 position:105 position:70 position:70 position:67 position:56 position:45 position:42 position:6 target:56816 target_cate:396 label:0 +history:5653 history:18042 history:21137 history:17277 history:23847 history:25109 history:21837 history:17163 history:22786 history:27380 
history:20789 history:27737 history:30164 history:36402 history:37166 history:38647 history:31746 history:38915 history:38366 history:11151 history:43757 history:38284 history:29817 history:41717 history:41899 history:43279 history:47539 history:37850 history:39789 history:43817 history:11208 history:53361 history:29247 history:51483 history:39940 history:50917 history:53618 history:44055 history:48997 cate:593 cate:251 cate:616 cate:110 cate:110 cate:110 cate:110 cate:105 cate:436 cate:558 cate:311 cate:142 cate:603 cate:738 cate:398 cate:766 cate:1 cate:351 cate:142 cate:584 cate:674 cate:597 cate:142 cate:483 cate:351 cate:157 cate:373 cate:142 cate:629 cate:39 cate:708 cate:251 cate:339 cate:142 cate:262 cate:1 cate:113 cate:142 cate:462 position:1285 position:1258 position:1252 position:1206 position:1206 position:1206 position:1205 position:1194 position:1187 position:992 position:804 position:791 position:703 position:670 position:640 position:549 position:548 position:542 position:489 position:480 position:479 position:455 position:422 position:393 position:319 position:296 position:274 position:266 position:266 position:266 position:222 position:141 position:127 position:127 position:114 position:88 position:56 position:22 position:8 target:13418 target_cate:558 label:0 +history:8719 history:11172 cate:311 cate:217 position:0 position:0 target:11707 target_cate:179 label:1 +history:14968 history:8297 history:22914 history:5998 history:20253 history:41425 history:42664 history:46745 history:51179 history:33481 history:46814 history:55135 history:53124 history:61559 cate:463 cate:766 cate:714 cate:486 cate:628 cate:444 cate:281 cate:714 cate:142 cate:242 cate:174 cate:118 cate:714 cate:714 position:2006 position:1413 position:1323 position:1148 position:977 position:777 position:589 position:487 position:486 position:403 position:349 position:297 position:78 position:12 target:61908 target_cate:714 label:1 +history:61119 cate:714 position:99 target:22907 
target_cate:83 label:0 +history:26172 cate:157 position:258 target:54529 target_cate:44 label:0 +history:13830 history:10377 history:8193 history:16072 history:13543 history:18741 history:24205 history:18281 history:37272 history:27784 history:16658 history:27884 cate:384 cate:739 cate:558 cate:739 cate:135 cate:347 cate:558 cate:687 cate:498 cate:142 cate:197 cate:746 position:1447 position:1443 position:1380 position:1312 position:936 position:876 position:695 position:523 position:55 position:25 position:24 position:20 target:34463 target_cate:177 label:1 +history:20842 history:11756 history:22110 history:30562 history:30697 cate:189 cate:68 cate:483 cate:776 cate:225 position:516 position:55 position:21 position:21 position:21 target:49113 target_cate:483 label:0 +history:13646 history:46782 history:54138 cate:142 cate:798 cate:142 position:604 position:346 position:200 target:43698 target_cate:347 label:0 +history:36434 cate:241 position:31 target:51537 target_cate:629 label:0 +history:44121 history:35325 cate:397 cate:653 position:809 position:0 target:43399 target_cate:397 label:1 +history:6438 history:11107 history:20073 history:25026 history:24434 history:35533 history:6318 history:25028 history:28352 history:32359 history:25734 history:26280 history:41466 history:25192 history:1909 history:11753 history:17770 history:24301 history:1728 history:9693 history:36444 history:40256 history:17961 history:36780 history:41093 history:8788 history:439 history:46397 history:46269 history:50462 history:40395 history:437 history:2582 history:4455 history:12361 history:14325 history:22294 history:26153 history:26607 history:29205 history:29878 history:33491 history:38795 history:41585 history:45480 history:51567 history:54245 history:19796 history:52446 cate:356 cate:194 cate:389 cate:89 cate:474 cate:330 cate:347 cate:384 cate:330 cate:90 cate:19 cate:385 cate:177 cate:68 cate:624 cate:68 cate:674 cate:463 cate:624 cate:194 cate:177 cate:389 cate:197 cate:642 cate:239 
cate:111 cate:115 cate:113 cate:48 cate:251 cate:554 cate:115 cate:36 cate:163 cate:616 cate:524 cate:84 cate:190 cate:465 cate:398 cate:89 cate:166 cate:113 cate:330 cate:616 cate:449 cate:90 cate:140 cate:330 position:971 position:969 position:969 position:969 position:934 position:934 position:921 position:921 position:921 position:921 position:861 position:794 position:691 position:690 position:689 position:689 position:689 position:686 position:683 position:683 position:681 position:656 position:408 position:341 position:341 position:278 position:276 position:275 position:229 position:226 position:210 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:119 position:110 position:105 target:15142 target_cate:764 label:0 +history:1573 cate:540 position:0 target:18294 target_cate:463 label:1 +history:9837 history:13438 history:13690 cate:351 cate:629 cate:24 position:287 position:287 position:287 target:26044 target_cate:351 label:0 +history:1708 history:2675 history:4935 history:7401 history:14413 history:22177 history:30319 history:32217 history:34342 history:40235 history:42963 history:43949 history:54816 cate:463 cate:115 cate:474 cate:616 cate:474 cate:44 cate:113 cate:279 cate:164 cate:142 cate:616 cate:649 cate:36 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 target:31992 target_cate:115 label:0 +history:8025 history:11769 history:36188 history:42006 cate:142 cate:262 cate:714 cate:142 position:1107 position:1107 position:21 position:13 target:8209 target_cate:142 label:0 +history:30266 cate:176 position:0 target:44167 target_cate:692 label:0 +history:13000 history:14769 history:2940 history:27638 history:23158 cate:765 cate:27 cate:736 cate:554 cate:112 position:1155 position:797 position:348 position:348 
position:334 target:55050 target_cate:725 label:0 +history:32557 history:18668 history:43441 cate:765 cate:707 cate:396 position:1 position:0 position:0 target:44217 target_cate:681 label:1 +history:5665 history:5964 history:18874 cate:542 cate:746 cate:196 position:1229 position:1202 position:123 target:16747 target_cate:179 label:0 +history:7014 history:29912 history:42468 cate:194 cate:612 cate:558 position:2424 position:0 position:0 target:20800 target_cate:355 label:0 +history:8320 history:9743 history:1735 history:442 history:5216 history:11568 cate:234 cate:251 cate:241 cate:603 cate:476 cate:649 position:211 position:70 position:61 position:34 position:34 position:27 target:32738 target_cate:153 label:0 +history:533 history:1447 cate:744 cate:744 position:664 position:337 target:17843 target_cate:744 label:1 +history:48390 history:48191 cate:714 cate:714 position:137 position:92 target:48864 target_cate:708 label:1 +history:9312 history:16166 history:12754 history:21433 history:28142 history:7486 cate:215 cate:674 cate:241 cate:115 cate:558 cate:241 position:1910 position:1045 position:414 position:371 position:371 position:347 target:38629 target_cate:48 label:1 +history:10401 history:11665 history:10739 cate:142 cate:364 cate:766 position:363 position:217 position:48 target:5989 target_cate:463 label:0 +history:10408 history:14363 history:8807 history:14947 history:24701 history:44676 history:40914 history:12241 history:14906 history:29247 history:32347 history:5834 history:18291 history:18313 history:23375 history:24075 history:7020 history:14307 history:15891 cate:140 cate:140 cate:749 cate:281 cate:444 cate:388 cate:504 cate:385 cate:196 cate:339 cate:746 cate:351 cate:463 cate:746 cate:197 cate:90 cate:746 cate:576 cate:476 position:1338 position:1336 position:1305 position:1267 position:835 position:88 position:87 position:86 position:86 position:86 position:86 position:84 position:84 position:84 position:84 position:84 position:83 position:83 
position:83 target:37949 target_cate:330 label:1 +history:50194 cate:444 position:243 target:15572 target_cate:216 label:0 +history:24021 cate:281 position:718 target:25850 target_cate:140 label:1 +history:22185 history:28726 history:55777 cate:142 cate:766 cate:351 position:923 position:923 position:133 target:17 target_cate:541 label:1 +history:31776 history:34767 history:28854 history:34769 history:38022 history:38667 history:32917 history:9094 history:40879 history:41634 history:42252 history:19865 history:47983 history:38818 history:40131 history:40690 history:18915 history:48539 history:49619 history:18554 history:24836 cate:70 cate:239 cate:113 cate:48 cate:486 cate:541 cate:352 cate:197 cate:347 cate:385 cate:34 cate:476 cate:704 cate:388 cate:385 cate:281 cate:225 cate:474 cate:157 cate:706 cate:53 position:490 position:490 position:473 position:360 position:360 position:360 position:209 position:199 position:199 position:199 position:199 position:198 position:198 position:196 position:196 position:174 position:93 position:36 position:36 position:0 position:0 target:25602 target_cate:707 label:1 +history:10544 history:15159 history:23606 history:33556 history:46886 history:55061 history:2079 history:27022 history:40345 history:43556 history:3807 history:28732 cate:642 cate:87 cate:641 cate:113 cate:558 cate:157 cate:564 cate:44 cate:194 cate:26 cate:54 cate:113 position:844 position:362 position:362 position:362 position:362 position:362 position:205 position:205 position:205 position:205 position:0 position:0 target:51293 target_cate:272 label:0 +history:19005 history:41469 history:42368 history:5739 history:30169 history:32266 history:54743 history:56959 history:26271 cate:145 cate:482 cate:707 cate:790 cate:101 cate:347 cate:197 cate:368 cate:674 position:365 position:365 position:365 position:258 position:258 position:258 position:258 position:258 position:0 target:5602 target_cate:158 label:0 +history:7166 history:16886 history:21083 history:7328 
history:25545 cate:560 cate:213 cate:87 cate:744 cate:87 position:474 position:474 position:474 position:214 position:214 target:32494 target_cate:321 label:1 +history:2306 cate:260 position:51 target:30286 target_cate:179 label:0 +history:57709 history:55115 cate:351 cate:483 position:99 position:50 target:25035 target_cate:142 label:0 +history:16641 history:35845 cate:153 cate:311 position:0 position:0 target:36985 target_cate:68 label:1 +history:31144 history:4107 cate:189 cate:168 position:1179 position:0 target:50619 target_cate:142 label:0 +history:36331 history:9873 history:10659 history:14382 history:21430 history:28164 cate:680 cate:197 cate:185 cate:11 cate:115 cate:476 position:278 position:0 position:0 position:0 position:0 position:0 target:37887 target_cate:484 label:1 +history:19519 history:3748 history:33772 history:22436 history:38789 history:46337 cate:649 cate:351 cate:210 cate:115 cate:113 cate:115 position:1038 position:517 position:470 position:349 position:150 position:37 target:23980 target_cate:649 label:1 +history:30789 history:37586 history:42354 history:26171 history:15017 history:28654 history:44960 cate:142 cate:714 cate:142 cate:483 cate:484 cate:474 cate:157 position:158 position:158 position:146 position:36 position:26 position:26 position:26 target:41552 target_cate:746 label:1 +history:52662 cate:576 position:0 target:53627 target_cate:776 label:0 +history:12258 history:15133 history:15681 history:5066 history:6420 history:13421 history:6577 history:29202 history:38939 cate:216 cate:558 cate:111 cate:570 cate:447 cate:5 cate:111 cate:281 cate:347 position:1544 position:1359 position:1312 position:743 position:743 position:636 position:560 position:103 position:24 target:7818 target_cate:558 label:0 +history:610 history:1258 history:2332 history:7508 history:10814 history:10797 history:11710 cate:543 cate:611 cate:611 cate:653 cate:110 cate:201 cate:179 position:2452 position:1361 position:935 position:669 position:524 position:55 
position:45 target:11495 target_cate:558 label:1 +history:12584 history:2707 history:1664 history:25878 history:25949 cate:790 cate:694 cate:694 cate:142 cate:611 position:768 position:729 position:625 position:236 position:7 target:25286 target_cate:792 label:1 +history:32423 history:24223 cate:135 cate:90 position:421 position:76 target:2323 target_cate:399 label:0 +history:11959 cate:197 position:0 target:15349 target_cate:351 label:1 +history:44448 history:58138 history:41930 history:57603 history:59009 history:61316 history:61559 history:599 cate:339 cate:629 cate:115 cate:388 cate:1 cate:142 cate:714 cate:297 position:320 position:97 position:23 position:23 position:23 position:23 position:23 position:0 target:54434 target_cate:142 label:0 +history:43441 history:12617 history:47970 history:52144 cate:396 cate:196 cate:142 cate:629 position:213 position:208 position:208 position:208 target:29211 target_cate:351 label:1 +history:25327 history:40258 cate:656 cate:398 position:676 position:3 target:40261 target_cate:142 label:1 +history:4637 cate:474 position:62 target:59864 target_cate:687 label:0 diff --git a/models/rank/BST/model.py b/models/rank/BST/model.py new file mode 100755 index 00000000..101cb792 --- /dev/null +++ b/models/rank/BST/model.py @@ -0,0 +1,347 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import math
+from functools import partial
+
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+
+from paddlerec.core.utils import envs
+from paddlerec.core.model import ModelBase
+
+
+def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
+    """
+    Position-wise Feed-Forward Networks.
+    This module consists of two linear transformations with a ReLU activation
+    in between, which is applied to each position separately and identically.
+
+    Args:
+        x: input tensor. Both fc layers are built with num_flatten_dims=2.
+        d_inner_hid: output size of the first (ReLU) projection.
+        d_hid: output size of the second projection.
+        dropout_rate: dropout probability applied after the ReLU; any falsy
+            value (0, 0.0, None) disables dropout.
+
+    Returns:
+        The output of the second projection.
+    """
+    hidden = layers.fc(input=x,
+                       size=d_inner_hid,
+                       num_flatten_dims=2,
+                       act="relu")
+    if dropout_rate:
+        # This file defines no `dropout_seed`; referencing it raised
+        # NameError as soon as dropout was enabled. seed=None matches the
+        # attention dropout in Model.multi_head_attention.
+        hidden = layers.dropout(
+            hidden, dropout_prob=dropout_rate, seed=None, is_test=False)
+    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
+    return out
+
+
+def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
+    """
+    Add residual connection, layer normalization and dropout to the out tensor
+    optionally according to the value of process_cmd.
+    This will be used before or after multi-head attention and position-wise
+    feed-forward networks.
+
+    Args:
+        prev_out: tensor added to `out` by the "a" command, or None to skip
+            the residual connection.
+        out: tensor to process.
+        process_cmd: string of one-letter commands applied left to right:
+            "a" (add residual), "n" (layer normalization), "d" (dropout).
+            Any other character is ignored.
+        dropout_rate: dropout probability used by "d"; a falsy value
+            disables dropout.
+
+    Returns:
+        The processed tensor.
+    """
+    for cmd in process_cmd:
+        if cmd == "a":  # add residual connection
+            # Test for None explicitly instead of relying on the truthiness
+            # of a tensor object.
+            if prev_out is not None:
+                out = out + prev_out
+        elif cmd == "n":  # add layer normalization
+            out = layers.layer_norm(
+                out,
+                begin_norm_axis=len(out.shape) - 1,
+                param_attr=fluid.initializer.Constant(1.),
+                bias_attr=fluid.initializer.Constant(0.))
+        elif cmd == "d":  # add dropout
+            if dropout_rate:
+                # `dropout_seed` is not defined anywhere in this file (it
+                # raised NameError); use seed=None as the attention dropout
+                # does.
+                out = layers.dropout(
+                    out, dropout_prob=dropout_rate, seed=None, is_test=False)
+    return out
+
+
+# Pre-processing has no residual input, so prev_out is bound to None.
+pre_process_layer = partial(pre_post_process_layer, None)
+post_process_layer = pre_post_process_layer
+
+
+class Model(ModelBase):
+    """BST rank model: embeds the behavior sequence and the target item,
+    runs transformer encoder layers over them and scores with an MLP."""
+
+    def __init__(self, config):
+        ModelBase.__init__(self, config)
+
+    def _init_hyper_parameters(self):
+        """Read every `hyper_parameters.*` setting, with its default, into
+        attributes on self."""
+        self.item_emb_size = envs.get_global_env(
+            "hyper_parameters.item_emb_size", 64)
+        self.cat_emb_size = envs.get_global_env(
+            "hyper_parameters.cat_emb_size", 64)
+        self.position_emb_size = envs.get_global_env(
+            "hyper_parameters.position_emb_size", 64)
+        self.act = envs.get_global_env("hyper_parameters.act", "sigmoid")
+        self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
+                                             False)
+        # significant for speeding up the training process
+        self.use_DataLoader = envs.get_global_env(
+            "hyper_parameters.use_DataLoader", False)
+        self.item_count = envs.get_global_env("hyper_parameters.item_count",
+                                              63001)
+        self.cat_count = envs.get_global_env("hyper_parameters.cat_count", 801)
+        self.position_count = envs.get_global_env(
+            "hyper_parameters.position_count", 5001)
+        self.n_encoder_layers = envs.get_global_env(
+            "hyper_parameters.n_encoder_layers", 1)
+        self.d_model = envs.get_global_env("hyper_parameters.d_model", 96)
+        # d_key, d_value and n_head have no usable default; they must be set
+        # in the config (multi_head_attention rejects None).
+        self.d_key = envs.get_global_env("hyper_parameters.d_key", None)
+        self.d_value = envs.get_global_env("hyper_parameters.d_value", None)
+        self.n_head = envs.get_global_env("hyper_parameters.n_head", None)
+        self.dropout_rate = envs.get_global_env(
+            "hyper_parameters.dropout_rate", 0.0)
+        self.postprocess_cmd = envs.get_global_env(
+            "hyper_parameters.postprocess_cmd", "da")
+        # This used to read "hyper_parameters.postprocess_cmd" (copy-paste
+        # slip), so the pre-process command silently mirrored the
+        # post-process one whenever that key was configured.
+        self.preprocess_cmd = envs.get_global_env(
+            "hyper_parameters.preprocess_cmd", "n")
+        self.prepostprocess_dropout = envs.get_global_env(
+            "hyper_parameters.prepostprocess_dropout", 0.0)
+        self.d_inner_hid = envs.get_global_env("hyper_parameters.d_inner_hid",
+                                               512)
+        self.relu_dropout = envs.get_global_env(
+            "hyper_parameters.relu_dropout", 0.0)
+        self.layer_sizes = envs.get_global_env("hyper_parameters.fc_sizes",
+                                               None)
+
+    def multi_head_attention(self, queries, keys, values, d_key, d_value,
+                             d_model, n_head, dropout_rate):
+        """
+        Multi-head scaled dot-product attention followed by an output
+        projection.
+
+        Args:
+            queries: 3-D tensor of queries.
+            keys: 3-D tensor of keys, or None to reuse `queries`.
+            values: 3-D tensor of values, or None to reuse `keys`.
+            d_key: per-head size of the query and key projections.
+            d_value: per-head size of the value projection.
+            d_model: size of the final output projection.
+            n_head: number of attention heads.
+            dropout_rate: dropout probability on the attention weights; a
+                falsy value disables dropout.
+
+        Returns:
+            The attention output after the final projection to d_model.
+
+        Raises:
+            ValueError: if queries/keys/values are not all 3-D, or if d_key,
+                d_value or n_head is None.
+        """
+        keys = queries if keys is None else keys
+        values = keys if values is None else values
+        if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3
+                ):
+            raise ValueError(
+                "Inputs: quries, keys and values should all be 3-D tensors.")
+        # These default to None in _init_hyper_parameters; fail with a clear
+        # message instead of a TypeError from `None * None` further down.
+        if d_key is None or d_value is None or n_head is None:
+            raise ValueError(
+                "d_key, d_value and n_head must be set in hyper_parameters.")
+
+        def _compute_qkv(queries, keys, values, n_head, d_key, d_value):
+            """
+            Add linear projection to queries, keys, and values.
+            """
+            q = fluid.layers.fc(input=queries,
+                                size=d_key * n_head,
+                                bias_attr=False,
+                                num_flatten_dims=2)
+            k = fluid.layers.fc(input=keys,
+                                size=d_key * n_head,
+                                bias_attr=False,
+                                num_flatten_dims=2)
+            v = fluid.layers.fc(input=values,
+                                size=d_value * n_head,
+                                bias_attr=False,
+                                num_flatten_dims=2)
+            return q, k, v
+
+        def _split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
+            """
+            Reshape input tensors at the last dimension to split multi-heads
+            and then transpose. Specifically, transform the input tensor with shape
+            [bs, max_sequence_length, n_head * hidden_dim] to the output tensor
+            with shape [bs, n_head, max_sequence_length, hidden_dim].
+            """
+            # The value 0 in shape attr means copying the corresponding dimension
+            # size of the input as the output dimension size.
+            reshaped_q = fluid.layers.reshape(
+                x=queries, shape=[0, 0, n_head, d_key], inplace=True)
+            # permute the dimensions into:
+            # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
+            q = fluid.layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
+            reshaped_k = fluid.layers.reshape(
+                x=keys, shape=[0, 0, n_head, d_key], inplace=True)
+            k = fluid.layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
+            reshaped_v = fluid.layers.reshape(
+                x=values, shape=[0, 0, n_head, d_value], inplace=True)
+            v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
+
+            return q, k, v
+
+        def scaled_dot_product_attention(q, k, v, d_key, dropout_rate):
+            """
+            Scaled Dot-Product Attention
+            """
+            # alpha scales the q.k^T product by d_key ** -0.5.
+            product = fluid.layers.matmul(
+                x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
+
+            weights = fluid.layers.softmax(product)
+            if dropout_rate:
+                weights = fluid.layers.dropout(
+                    weights,
+                    dropout_prob=dropout_rate,
+                    seed=None,
+                    is_test=False)
+            out = fluid.layers.matmul(weights, v)
+            return out
+
+        def _combine_heads(x):
+            """
+            Transpose and then reshape the last two dimensions of input tensor x
+            so that it becomes one dimension, which is reverse to _split_heads_qkv.
+            """
+            if len(x.shape) != 4:
+                raise ValueError("Input(x) should be a 4-D Tensor.")
+
+            trans_x = fluid.layers.transpose(x, perm=[0, 2, 1, 3])
+            # The value 0 in shape attr means copying the corresponding dimension
+            # size of the input as the output dimension size.
+            return fluid.layers.reshape(
+                x=trans_x,
+                shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
+                inplace=True)
+
+        q, k, v = _compute_qkv(queries, keys, values, n_head, d_key, d_value)
+        q, k, v = _split_heads_qkv(q, k, v, n_head, d_key, d_value)
+
+        # Pass d_key: the helper scales by its `d_key` argument, and d_model
+        # was previously passed here, giving a d_model ** -0.5 scale.
+        ctx_multiheads = scaled_dot_product_attention(q, k, v, d_key,
+                                                      dropout_rate)
+
+        out = _combine_heads(ctx_multiheads)
+
+        proj_out = fluid.layers.fc(input=out,
+                                   size=d_model,
+                                   bias_attr=False,
+                                   num_flatten_dims=2)
+
+        return proj_out
+
+    def encoder_layer(self, x):
+        """
+        One encoder layer: self-attention then a position-wise feed-forward
+        network, each wrapped by the configured pre-/post-process commands.
+
+        Args:
+            x: input tensor of the layer (also used as the residual input of
+                the attention sub-layer).
+
+        Returns:
+            The post-processed feed-forward output.
+        """
+        # keys/values are None, so attention is computed over x itself.
+        attention_out = self.multi_head_attention(
+            pre_process_layer(x, self.preprocess_cmd,
+                              self.prepostprocess_dropout), None, None,
+            self.d_key, self.d_value, self.d_model, self.n_head,
+            self.dropout_rate)
+        attn_output = post_process_layer(x, attention_out,
+                                         self.postprocess_cmd,
+                                         self.prepostprocess_dropout)
+        ffd_output = positionwise_feed_forward(
+            pre_process_layer(attn_output, self.preprocess_cmd,
+                              self.prepostprocess_dropout), self.d_inner_hid,
+            self.d_model, self.relu_dropout)
+        return post_process_layer(attn_output, ffd_output,
+                                  self.postprocess_cmd,
+                                  self.prepostprocess_dropout)
+
+    def net(self, inputs, is_infer=False):
+        """
+        Build the network, the loss and the AUC metrics.
+
+        Sets self.predict, self._cost and the "AUC"/"BATCH_AUC" entries of
+        self._metrics.
+
+        Args:
+            inputs: not used here; the input variables are taken from
+                self._sparse_data_var instead.
+            is_infer: when True, the AUC is also stored in
+                self._infer_results.
+        """
+
+        init_value_ = 0.1
+
+        # NOTE(review): the index order presumably follows the sparse slot
+        # order declared in config.yaml - confirm against the config.
+        hist_item_seq = self._sparse_data_var[1]
+        hist_cat_seq = self._sparse_data_var[2]
+        position_seq = self._sparse_data_var[3]
+        target_item = self._sparse_data_var[4]
+        target_cat = self._sparse_data_var[5]
+        target_position = self._sparse_data_var[6]
+        self.label = self._sparse_data_var[0]
+
+        # Named attrs so the history and target lookups share one table.
+        item_emb_attr = fluid.ParamAttr(name="item_emb")
+        cat_emb_attr = fluid.ParamAttr(name="cat_emb")
+        position_emb_attr = fluid.ParamAttr(name="position_emb")
+
+        hist_item_emb = fluid.embedding(
+            input=hist_item_seq,
+            size=[self.item_count, self.item_emb_size],
+            param_attr=item_emb_attr,
+            is_sparse=self.is_sparse)
+
+        hist_cat_emb = fluid.embedding(
+            input=hist_cat_seq,
+            size=[self.cat_count, self.cat_emb_size],
+            param_attr=cat_emb_attr,
+            is_sparse=self.is_sparse)
+
+        # Look up the position ids; this previously indexed the position
+        # table with hist_cat_seq and left position_seq unused.
+        hist_position_emb = fluid.embedding(
+            input=position_seq,
+            size=[self.position_count, self.position_emb_size],
+            param_attr=position_emb_attr,
+            is_sparse=self.is_sparse)
+
+        target_item_emb = fluid.embedding(
+            input=target_item,
+            size=[self.item_count, self.item_emb_size],
+            param_attr=item_emb_attr,
+            is_sparse=self.is_sparse)
+
+        target_cat_emb = fluid.embedding(
+            input=target_cat,
+            size=[self.cat_count, self.cat_emb_size],
+            param_attr=cat_emb_attr,
+            is_sparse=self.is_sparse)
+
+        target_position_emb = fluid.embedding(
+            input=target_position,
+            size=[self.position_count, self.position_emb_size],
+            param_attr=position_emb_attr,
+            is_sparse=self.is_sparse)
+
+        # Append the target to each history sequence, then sum over dim 1.
+        item_sequence_target = fluid.layers.reduce_sum(
+            fluid.layers.sequence_concat([hist_item_emb, target_item_emb]),
+            dim=1)
+        cat_sequence_target = fluid.layers.reduce_sum(
+            fluid.layers.sequence_concat([hist_cat_emb, target_cat_emb]),
+            dim=1)
+        position_sequence_target = fluid.layers.reduce_sum(
+            fluid.layers.sequence_concat(
+                [hist_position_emb, target_position_emb]),
+            dim=1)
+
+        whole_embedding_withlod = fluid.layers.concat(
+            [
+                item_sequence_target, cat_sequence_target,
+                position_sequence_target
+            ],
+            axis=1)
+        pad_value = fluid.layers.assign(input=np.array(
+            [0.0], dtype=np.float32))
+        whole_embedding, _ = fluid.layers.sequence_pad(whole_embedding_withlod,
+                                                       pad_value)
+
+        # Chain the encoder layers: each layer consumes the previous layer's
+        # output. Previously every layer was fed whole_embedding, so only the
+        # last layer contributed, and enc_output was undefined when
+        # n_encoder_layers was 0.
+        enc_output = whole_embedding
+        for _ in range(self.n_encoder_layers):
+            enc_output = self.encoder_layer(enc_output)
+        enc_output = pre_process_layer(enc_output, self.preprocess_cmd,
+                                       self.prepostprocess_dropout)
+
+        dnn_input = fluid.layers.reduce_sum(enc_output, dim=1)
+
+        # fc_sizes defaults to None; treat that as "no hidden layers".
+        for s in self.layer_sizes or []:
+            dnn_input = fluid.layers.fc(
+                input=dnn_input,
+                size=s,
+                act=self.act,
+                param_attr=fluid.ParamAttr(
+                    initializer=fluid.initializer.TruncatedNormalInitializer(
+                        loc=0.0, scale=init_value_ / math.sqrt(float(10)))),
+                bias_attr=fluid.ParamAttr(
+                    initializer=fluid.initializer.TruncatedNormalInitializer(
+                        loc=0.0, scale=init_value_)))
+
+        y_dnn = fluid.layers.fc(input=dnn_input, size=1, act=None)
+
+        self.predict = fluid.layers.sigmoid(y_dnn)
+        cost = fluid.layers.log_loss(
+            input=self.predict, label=fluid.layers.cast(self.label, "float32"))
+        # Despite the name, this is the summed (not averaged) log loss.
+        avg_cost = fluid.layers.reduce_sum(cost)
+
+        self._cost = avg_cost
+
+        # auc() is given both class columns: [1 - p, p].
+        predict_2d = fluid.layers.concat([1 - self.predict, self.predict], 1)
+        label_int = fluid.layers.cast(self.label, 'int64')
+        auc_var, batch_auc_var, _ = fluid.layers.auc(input=predict_2d,
+                                                     label=label_int,
+                                                     slide_steps=0)
+        self._metrics["AUC"] = auc_var
+        self._metrics["BATCH_AUC"] = batch_auc_var
+        if is_infer:
+            self._infer_results["AUC"] = auc_var
diff --git a/models/rank/readme.md b/models/rank/readme.md
index da242481..cb34dcd6 100644
--- a/models/rank/readme.md
+++ b/models/rank/readme.md
@@ -37,6 +37,7 @@
 | xDeepFM | xDeepFM | [xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023)(2018) |
 | DIN | Deep Interest Network | [Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823)(2018) |
 | DIEN | Deep Interest Evolution Network | [Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423)(2019) |
+| BST | transformer in user behavior sequence for rank | [Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf)(2019) |
 | FGCNN | Feature Generation by CNN | [Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf)(2019) |
 | FIBINET | Combining Feature Importance and Bilinear feature Interaction | [《FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction》]( https://arxiv.org/pdf/1905.09433.pdf)(2019) |
 | FLEN | Leveraging Field for Scalable CTR Prediction | [《FLEN: Leveraging Field for Scalable CTR
Prediction》]( https://arxiv.org/pdf/1911.04690.pdf)(2019) | -- GitLab