未验证 提交 b8e17866 编写于 作者: Y yaoxuefeng 提交者: GitHub

add rank model BST (#134)

* add rank model BST

* update readme
上级 5fd7f899
......@@ -56,6 +56,7 @@
| Rank | [xDeepFM](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023) |
| Rank | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823) |
| Rank | [DIEN](models/rank/dien/model.py) | ✓ | x | ✓ | x | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423) |
| Rank | [BST](models/rank/BST/model.py) | ✓ | x | ✓ | x | [DLP-KDD 2019][Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf) |
| Rank | [AutoInt](models/rank/AutoInt/model.py) | ✓ | x | ✓ | x | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf) |
| Rank | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/2988450.2988454) |
| Rank | [FGCNN](models/rank/fgcnn/model.py) | ✓ | ✓ | ✓ | ✓ | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf) |
......
......@@ -61,6 +61,7 @@
| 排序 | [xDeepFM](models/rank/xdeepfm/model.py) | ✓ | x | ✓ | x | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023) |
| 排序 | [DIN](models/rank/din/model.py) | ✓ | x | ✓ | x | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823) |
| 排序 | [DIEN](models/rank/dien/model.py) | ✓ | x | ✓ | x | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423) |
| 排序 | [BST](models/rank/BST/model.py) | ✓ | x | ✓ | x | [DLP_KDD 2019][Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf) |
| 排序 | [AutoInt](models/rank/AutoInt/model.py) | ✓ | x | ✓ | x | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf) |
| 排序 | [Wide&Deep](models/rank/wide_deep/model.py) | ✓ | x | ✓ | x | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/2988450.2988454) |
| 排序 | [FGCNN](models/rank/fgcnn/model.py) | ✓ | ✓ | ✓ | ✓ | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf) |
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# global settings
debug: false
workspace: "paddlerec.models.rank.BST"
dataset:
- name: sample_1
type: DataLoader
batch_size: 5
data_path: "{workspace}/data/train_data"
sparse_slots: "label history cate position target target_cate target_position"
- name: infer_sample
type: DataLoader
batch_size: 5
data_path: "{workspace}/data/train_data"
sparse_slots: "label history cate position target target_cate target_position"
hyper_parameters:
optimizer:
class: SGD
learning_rate: 0.0001
use_DataLoader: True
item_emb_size: 96
cat_emb_size: 96
position_emb_size: 96
is_sparse: False
item_count: 63001
cat_count: 801
position_count: 5001
n_encoder_layers: 1
d_model: 288
d_key: 48
d_value: 48
n_head: 6
dropout_rate: 0
postprocess_cmd: "da"
prepostprocess_dropout: 0
d_inner_hid: 512
relu_dropout: 0.0
act: "relu"
fc_sizes: [1024, 512, 256]
mode: train_runner
runner:
- name: train_runner
class: train
epochs: 1
device: cpu
init_model_path: ""
save_checkpoint_interval: 1
save_inference_interval: 1
save_checkpoint_path: "increment_BST"
save_inference_path: "inference_BST"
print_interval: 1
- name: infer_runner
class: infer
device: cpu
init_model_path: "increment_BST/0"
print_interval: 1
phase:
- name: phase1
model: "{workspace}/model.py"
dataset_name: sample_1
thread_num: 1
#- name: infer_phase
# model: "{workspace}/model.py"
# dataset_name: infer_sample
# thread_num: 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import random
import pickle
random.seed(1234)
print("read and process data")
with open('./raw_data/remap.pkl', 'rb') as f:
reviews_df = pickle.load(f)
cate_list = pickle.load(f)
user_count, item_count, cate_count, example_count = pickle.load(f)
train_set = []
test_set = []
for reviewerID, hist in reviews_df.groupby('reviewerID'):
pos_list = hist['asin'].tolist()
time_list = hist['unixReviewTime'].tolist()
def gen_neg():
neg = pos_list[0]
while neg in pos_list:
neg = random.randint(0, item_count - 1)
return neg
neg_list = [gen_neg() for i in range(len(pos_list))]
for i in range(1, len(pos_list)):
hist = pos_list[:i]
# set maximum position value
time_seq = [
min(int((time_list[i] - time_list[j]) / (3600 * 24)), 5000)
for j in range(i)
]
if i != len(pos_list) - 1:
train_set.append((reviewerID, hist, pos_list[i], 1, time_seq))
train_set.append((reviewerID, hist, neg_list[i], 0, time_seq))
else:
label = (pos_list[i], neg_list[i])
test_set.append((reviewerID, hist, label, time_seq))
random.shuffle(train_set)
random.shuffle(test_set)
assert len(test_set) == user_count
def print_to_file(data, fout, slot):
if not isinstance(data, list):
data = [data]
for i in range(len(data)):
fout.write(slot + ":" + str(data[i]))
fout.write(' ')
print("make train data")
with open("paddle_train.txt", "w") as fout:
for line in train_set:
history = line[1]
target = line[2]
label = line[3]
position = line[4]
cate = [cate_list[x] for x in history]
print_to_file(history, fout, "history")
print_to_file(cate, fout, "cate")
print_to_file(position, fout, "position")
print_to_file(target, fout, "target")
print_to_file(cate_list[target], fout, "target_cate")
print_to_file(0, fout, "target_position")
print_to_file(label, fout, "label")
fout.write("\n")
print("make test data")
with open("paddle_test.txt", "w") as fout:
for line in test_set:
history = line[1]
target = line[2]
position = line[3]
cate = [cate_list[x] for x in history]
print_to_file(history, fout, "history")
print_to_file(cate, fout, "cate")
print_to_file(position, fout, "position")
print_to_file(target[0], fout, "target")
print_to_file(cate_list[target[0]], fout, "target_cate")
print_to_file(0, fout, "target_position")
fout.write("label:1\n")
print_to_file(history, fout, "history")
print_to_file(cate, fout, "cate")
print_to_file(position, fout, "position")
print_to_file(target[0], fout, "target")
print_to_file(cate_list[target[1]], fout, "target_cate")
print_to_file(0, fout, "target_position")
fout.write("label:0\n")
print("make config data")
with open('config.txt', 'w') as f:
f.write(str(user_count) + "\n")
f.write(str(item_count) + "\n")
f.write(str(cate_count) + "\n")
f.wrire(str(50000) + "\n")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import pickle
import pandas as pd
def to_df(file_path):
with open(file_path, 'r') as fin:
df = {}
i = 0
for line in fin:
df[i] = eval(line)
i += 1
df = pd.DataFrame.from_dict(df, orient='index')
return df
print("start to analyse reviews_Electronics_5.json")
reviews_df = to_df('./raw_data/reviews_Electronics_5.json')
with open('./raw_data/reviews.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
print("start to analyse meta_Electronics.json")
meta_df = to_df('./raw_data/meta_Electronics.json')
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
meta_df = meta_df.reset_index(drop=True)
with open('./raw_data/meta.pkl', 'wb') as f:
pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)
#! /bin/bash
set -e
echo "begin download data"
mkdir raw_data
cd raw_data
wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
gzip -d reviews_Electronics_5.json.gz
wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz
gzip -d meta_Electronics.json.gz
echo "download data successfully"
cd ..
python convert_pd.py
python remap_id.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import random
import pickle
import numpy as np
random.seed(1234)
with open('./raw_data/reviews.pkl', 'rb') as f:
reviews_df = pickle.load(f)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
with open('./raw_data/meta.pkl', 'rb') as f:
meta_df = pickle.load(f)
meta_df = meta_df[['asin', 'categories']]
meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])
def build_map(df, col_name):
key = sorted(df[col_name].unique().tolist())
m = dict(zip(key, range(len(key))))
df[col_name] = df[col_name].map(lambda x: m[x])
return m, key
asin_map, asin_key = build_map(meta_df, 'asin')
cate_map, cate_key = build_map(meta_df, 'categories')
revi_map, revi_key = build_map(reviews_df, 'reviewerID')
user_count, item_count, cate_count, example_count =\
len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]
print('user_count: %d\titem_count: %d\tcate_count: %d\texample_count: %d' %
(user_count, item_count, cate_count, example_count))
meta_df = meta_df.sort_values('asin')
meta_df = meta_df.reset_index(drop=True)
reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])
reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])
reviews_df = reviews_df.reset_index(drop=True)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]
cate_list = np.array(cate_list, dtype=np.int32)
with open('./raw_data/remap.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid
pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line
pickle.dump((user_count, item_count, cate_count, example_count), f,
pickle.HIGHEST_PROTOCOL)
pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)
history:3737 history:19450 cate:288 cate:196 position:518 position:158 target:18486 target_cate:674 label:1
history:3647 history:4342 history:6855 history:3805 cate:281 cate:463 cate:558 cate:674 position:242 position:216 position:17 position:5 target:4206 target_cate:463 label:1
history:1805 history:4309 cate:87 cate:87 position:61 position:0 target:21354 target_cate:556 label:1
history:18209 history:20753 cate:649 cate:241 position:0 position:0 target:51924 target_cate:610 label:0
history:13150 cate:351 position:505 target:41455 target_cate:792 label:1
history:35120 history:40418 cate:157 cate:714 position:0 position:0 target:52035 target_cate:724 label:0
history:13515 history:20363 history:25356 history:26891 history:24200 history:11694 history:33378 history:34483 history:35370 history:27311 history:40689 history:33319 history:28819 cate:558 cate:123 cate:61 cate:110 cate:738 cate:692 cate:110 cate:629 cate:714 cate:463 cate:281 cate:142 cate:382 position:1612 position:991 position:815 position:668 position:639 position:508 position:456 position:431 position:409 position:222 position:221 position:74 position:34 target:45554 target_cate:558 label:1
history:19254 history:9021 history:28156 history:19193 history:24602 history:31171 cate:189 cate:462 cate:140 cate:474 cate:157 cate:614 position:375 position:144 position:141 position:0 position:0 position:0 target:48895 target_cate:350 label:1
history:4716 cate:194 position:2457 target:32497 target_cate:484 label:1
history:43799 history:47108 cate:368 cate:140 position:181 position:105 target:3503 target_cate:25 label:0
history:20554 history:41800 history:1582 history:1951 cate:339 cate:776 cate:694 cate:703 position:35 position:35 position:0 position:0 target:4320 target_cate:234 label:0
history:39713 history:44272 history:45136 history:11687 cate:339 cate:339 cate:339 cate:140 position:40 position:40 position:40 position:0 target:885 target_cate:168 label:0
history:14398 history:33997 cate:756 cate:347 position:73 position:73 target:20438 target_cate:703 label:1
history:29341 history:25727 cate:142 cate:616 position:839 position:0 target:4170 target_cate:512 label:0
history:12197 history:10212 cate:558 cate:694 position:1253 position:677 target:31559 target_cate:24 label:0
history:11551 cate:351 position:47 target:53485 target_cate:436 label:1
history:4553 cate:196 position:88 target:7331 target_cate:158 label:1
history:15190 history:19994 history:33946 history:30716 history:31879 history:45178 history:51598 history:46814 cate:249 cate:498 cate:612 cate:142 cate:746 cate:746 cate:558 cate:174 position:1912 position:1275 position:1170 position:1122 position:773 position:773 position:329 position:291 target:24353 target_cate:251 label:0
history:4931 history:2200 history:8338 history:23530 cate:785 cate:792 cate:277 cate:523 position:1360 position:975 position:975 position:586 target:3525 target_cate:251 label:0
history:8881 history:13274 history:12683 history:14696 history:27693 history:1395 history:44373 history:59704 history:27762 history:54268 history:30326 history:11811 history:45371 history:51598 history:55859 history:56039 history:57678 history:47250 history:2073 history:38932 cate:479 cate:558 cate:190 cate:708 cate:335 cate:684 cate:339 cate:725 cate:446 cate:446 cate:44 cate:575 cate:280 cate:558 cate:262 cate:197 cate:368 cate:111 cate:749 cate:188 position:2065 position:2065 position:1292 position:1108 position:647 position:343 position:343 position:343 position:257 position:257 position:143 position:76 position:76 position:76 position:76 position:76 position:76 position:58 position:6 position:6 target:12361 target_cate:616 label:1
history:16297 history:16797 history:18629 history:20922 history:16727 history:33946 history:51165 history:36796 cate:281 cate:436 cate:462 cate:339 cate:611 cate:612 cate:288 cate:64 position:1324 position:1324 position:1324 position:1118 position:183 position:133 position:6 position:4 target:34724 target_cate:288 label:1
history:22237 cate:188 position:339 target:40786 target_cate:637 label:0
history:5396 history:39993 history:42681 history:49832 history:11208 history:34954 history:36523 history:45523 history:51618 cate:351 cate:339 cate:687 cate:281 cate:708 cate:142 cate:629 cate:656 cate:142 position:1117 position:290 position:276 position:191 position:144 position:144 position:120 position:66 position:66 target:38201 target_cate:571 label:0
history:8881 history:9029 history:17043 history:16620 history:15021 history:32706 cate:479 cate:110 cate:110 cate:749 cate:598 cate:251 position:1218 position:1218 position:790 position:695 position:264 position:1 target:34941 target_cate:657 label:0
history:53255 cate:444 position:232 target:37953 target_cate:724 label:1
history:1010 history:4172 history:8613 history:11562 history:11709 history:13118 history:2027 history:15446 cate:674 cate:606 cate:708 cate:436 cate:179 cate:179 cate:692 cate:436 position:324 position:323 position:323 position:323 position:323 position:308 position:307 position:307 target:36998 target_cate:703 label:0
history:22357 history:24305 history:15222 history:19254 history:22914 cate:189 cate:504 cate:113 cate:189 cate:714 position:321 position:321 position:232 position:232 position:232 target:18201 target_cate:398 label:1
history:1905 cate:694 position:0 target:23877 target_cate:347 label:1
history:8444 history:17868 cate:765 cate:712 position:454 position:0 target:50732 target_cate:44 label:0
history:42301 history:26186 history:38086 cate:142 cate:450 cate:744 position:164 position:0 position:0 target:61547 target_cate:714 label:0
history:18156 history:35717 history:32070 history:45650 history:47208 history:20975 history:36409 history:44856 history:48072 history:15860 history:47043 history:53289 history:53314 history:33470 history:47926 cate:157 cate:281 cate:650 cate:142 cate:749 cate:291 cate:707 cate:714 cate:157 cate:205 cate:388 cate:474 cate:708 cate:498 cate:495 position:546 position:506 position:296 position:296 position:263 position:253 position:253 position:221 position:121 position:26 position:26 position:26 position:26 position:0 position:0 target:48170 target_cate:746 label:1
history:56219 cate:108 position:0 target:1988 target_cate:389 label:0
history:22907 cate:83 position:353 target:752 target_cate:175 label:0
history:22009 history:32410 history:42987 history:48720 history:683 history:1289 history:2731 history:4736 history:6306 history:8442 history:8946 history:9928 history:11536 history:14947 history:15793 history:16694 history:21736 history:25156 history:25797 history:25874 history:26573 history:30318 history:33946 history:35420 history:1492 history:5236 history:5555 history:6625 history:8867 history:9638 history:11443 history:20225 history:25965 history:27273 history:29001 history:35302 history:42336 history:43347 history:36907 history:2012 cate:317 cate:462 cate:291 cate:142 cate:694 cate:10 cate:574 cate:278 cate:708 cate:281 cate:131 cate:142 cate:367 cate:281 cate:258 cate:345 cate:616 cate:708 cate:111 cate:115 cate:339 cate:113 cate:612 cate:24 cate:368 cate:616 cate:39 cate:197 cate:44 cate:214 cate:558 cate:108 cate:616 cate:558 cate:210 cate:210 cate:142 cate:142 cate:262 cate:351 position:390 position:390 position:390 position:390 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:389 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:383 position:366 position:339 position:333 position:167 target:25540 target_cate:701 label:0
history:20434 cate:196 position:610 target:18056 target_cate:189 label:0
history:628 history:5461 cate:194 cate:234 position:294 position:74 target:43677 target_cate:351 label:0
history:16953 history:15149 history:45143 history:23587 history:5094 history:25105 history:51913 history:54645 cate:484 cate:281 cate:449 cate:792 cate:524 cate:395 cate:388 cate:731 position:1134 position:668 position:626 position:409 position:285 position:285 position:285 position:42 target:57655 target_cate:75 label:1
history:13584 history:7509 cate:234 cate:744 position:1187 position:231 target:33062 target_cate:749 label:1
history:170 history:208 history:77 history:109 history:738 history:742 history:1118 history:15349 history:255 history:12067 history:21643 history:55453 cate:330 cate:559 cate:744 cate:115 cate:558 cate:674 cate:111 cate:351 cate:694 cate:694 cate:746 cate:111 position:4920 position:4726 position:4585 position:4585 position:4585 position:4584 position:4108 position:1418 position:1326 position:274 position:89 position:88 target:9821 target_cate:694 label:1
history:4970 history:16672 cate:540 cate:746 position:416 position:120 target:25685 target_cate:666 label:1
history:17240 history:60546 cate:708 cate:629 position:165 position:41 target:42110 target_cate:142 label:1
history:31503 history:31226 history:50628 history:22444 cate:142 cate:156 cate:142 cate:203 position:187 position:162 position:109 position:0 target:47812 target_cate:749 label:0
history:2443 history:1763 history:3403 history:4225 history:8951 cate:25 cate:707 cate:351 cate:177 cate:351 position:1397 position:1113 position:973 position:637 position:254 target:7954 target_cate:351 label:1
history:3748 cate:351 position:1086 target:9171 target_cate:657 label:1
history:1755 history:26204 history:42716 history:32991 cate:446 cate:188 cate:497 cate:746 position:440 position:184 position:91 position:52 target:23910 target_cate:395 label:1
history:20637 history:27122 cate:558 cate:44 position:1122 position:0 target:19669 target_cate:301 label:0
history:406 history:872 history:306 history:218 history:883 history:1372 history:1705 history:1709 history:7774 history:2376 history:2879 history:2881 history:13329 history:4992 history:13594 history:11106 history:7131 history:8631 history:1736 history:17585 history:2568 history:16896 history:21971 history:10296 history:22361 history:24108 history:23300 history:11793 history:25351 history:2648 history:24593 history:12692 history:23883 history:25345 history:27129 history:26321 history:21627 history:20738 history:17784 history:28785 history:29281 history:28366 history:24723 history:24319 history:12083 history:29882 history:29974 history:30443 history:30428 history:17072 history:9783 history:16700 history:29421 history:32253 history:28830 history:31299 history:28792 history:33931 history:24973 history:33112 history:21717 history:28339 history:23978 history:18649 history:1841 history:17635 history:19696 history:37448 history:20862 history:30492 history:35736 history:37450 history:2633 history:8675 history:17412 history:25960 history:28389 history:31032 history:37157 history:14555 history:4996 history:33388 history:33393 history:36237 history:38946 history:22793 history:24337 history:34963 history:38819 history:41165 history:39551 history:43019 history:15570 history:25129 history:34593 history:38385 history:42915 history:41407 history:29907 history:31289 history:44229 history:24267 history:34975 history:39462 history:33274 history:43251 history:38302 history:35502 history:44056 history:44675 history:45233 history:47690 history:33472 history:50149 history:29409 history:47183 history:49188 history:48192 history:50628 history:24103 history:28313 history:28358 history:38882 history:44330 history:44346 history:2019 history:2484 history:2675 history:26396 history:48143 history:46039 history:47722 history:48559 history:41719 history:41720 history:43920 history:41983 history:51235 history:34964 history:27287 history:51915 history:33586 history:43630 history:47258 history:52137 history:40954 history:35120 history:29572 history:42405 history:53559 history:44900 history:45761 cate:241 cate:558 cate:395 cate:368 cate:498 cate:110 cate:463 cate:611 cate:558 cate:106 cate:10 cate:112 cate:251 cate:241 cate:48 cate:112 cate:601 cate:674 cate:241 cate:347 cate:733 cate:502 cate:194 cate:119 cate:179 cate:179 cate:578 cate:692 cate:281 cate:115 cate:523 cate:113 cate:281 cate:35 cate:765 cate:196 cate:339 cate:115 cate:90 cate:164 cate:790 cate:708 cate:142 cate:115 cate:342 cate:351 cate:391 cate:281 cate:48 cate:119 cate:74 cate:505 cate:606 cate:68 cate:239 cate:687 cate:687 cate:281 cate:110 cate:281 cate:449 cate:351 cate:38 cate:351 cate:164 cate:176 cate:449 cate:115 cate:70 cate:25 cate:687 cate:115 cate:39 cate:756 cate:35 cate:175 cate:704 cate:119 cate:38 cate:53 cate:115 cate:38 cate:38 cate:142 cate:262 cate:188 cate:614 cate:277 cate:388 cate:615 cate:49 cate:738 cate:106 cate:733 cate:486 cate:666 cate:571 cate:385 cate:708 cate:119 cate:331 cate:463 cate:578 cate:288 cate:142 cate:106 cate:611 cate:611 cate:39 cate:523 cate:388 cate:142 cate:726 cate:702 cate:498 cate:61 cate:142 cate:714 cate:142 cate:654 cate:277 cate:733 cate:603 cate:498 cate:299 cate:97 cate:726 cate:115 cate:637 cate:703 cate:558 cate:74 cate:629 cate:142 cate:142 cate:347 cate:629 cate:746 cate:277 cate:8 cate:49 cate:389 cate:629 cate:408 cate:733 cate:345 cate:157 cate:704 cate:115 cate:398 cate:611 cate:239 position:3925 position:3925 position:3909 position:3897 position:3879 position:3644 position:3611 position:3524 position:2264 position:1913 position:1730 position:1730 position:1684 position:1657 position:1643 position:1626 position:1566 position:1430 position:1375 position:1351 position:1298 position:1298 position:1221 position:1217 position:1177 position:1149 position:1142 position:1141 position:1083 position:1079 position:1067 position:1045 position:1031 position:997 position:994 position:993 position:987 position:968 position:946 position:945 position:905 position:904 position:903 position:897 position:856 position:855 position:813 position:813 position:801 position:799 position:798 position:791 position:791 position:767 position:765 position:761 position:756 position:751 position:747 position:730 position:672 position:659 position:652 position:620 position:619 position:619 position:597 position:596 position:582 position:555 position:555 position:532 position:484 position:484 position:484 position:483 position:483 position:468 position:468 position:467 position:454 position:454 position:454 position:441 position:427 position:409 position:409 position:409 position:409 position:409 position:387 position:387 position:381 position:381 position:381 position:360 position:360 position:357 position:355 position:337 position:332 position:317 position:294 position:271 position:213 position:206 position:204 position:202 position:202 position:182 position:182 position:173 position:154 position:142 position:135 position:114 position:110 position:107 position:107 position:95 position:95 position:95 position:95 position:94 position:92 position:90 position:90 position:90 position:90 position:90 position:86 position:86 position:86 position:84 position:84 position:84 position:83 position:83 position:80 position:65 position:51 position:41 position:23 position:23 position:23 position:22 position:18 position:7 position:3 position:3 position:0 position:0 target:49174 target_cate:368 label:0
history:29206 history:60955 cate:351 cate:684 position:32 position:32 target:61590 target_cate:76 label:1
history:8427 history:9692 history:4411 history:3266 history:18234 history:22774 cate:746 cate:281 cate:396 cate:651 cate:446 cate:44 position:1204 position:1129 position:808 position:622 position:134 position:134 target:23393 target_cate:351 label:0
history:13051 history:15844 history:9347 history:21973 history:18365 history:24220 history:28429 history:4799 history:27488 history:21623 history:13870 history:29346 history:27208 history:31075 history:31635 history:28390 history:30777 history:29334 history:33438 history:16469 history:29423 history:29237 history:25527 history:34808 history:37656 history:21324 history:38263 history:6699 history:33167 history:9295 history:40828 history:18894 cate:339 cate:342 cate:657 cate:194 cate:20 cate:466 cate:179 cate:225 cate:436 cate:364 cate:707 cate:115 cate:36 cate:523 cate:351 cate:674 cate:694 cate:391 cate:674 cate:500 cate:342 cate:216 cate:707 cate:345 cate:616 cate:495 cate:436 cate:363 cate:395 cate:189 cate:203 cate:766 position:1400 position:1032 position:849 position:827 position:804 position:469 position:467 position:463 position:460 position:456 position:455 position:451 position:371 position:371 position:371 position:315 position:315 position:314 position:311 position:287 position:282 position:281 position:239 position:105 position:105 position:70 position:70 position:67 position:56 position:45 position:42 position:6 target:56816 target_cate:396 label:0
history:5653 history:18042 history:21137 history:17277 history:23847 history:25109 history:21837 history:17163 history:22786 history:27380 history:20789 history:27737 history:30164 history:36402 history:37166 history:38647 history:31746 history:38915 history:38366 history:11151 history:43757 history:38284 history:29817 history:41717 history:41899 history:43279 history:47539 history:37850 history:39789 history:43817 history:11208 history:53361 history:29247 history:51483 history:39940 history:50917 history:53618 history:44055 history:48997 cate:593 cate:251 cate:616 cate:110 cate:110 cate:110 cate:110 cate:105 cate:436 cate:558 cate:311 cate:142 cate:603 cate:738 cate:398 cate:766 cate:1 cate:351 cate:142 cate:584 cate:674 cate:597 cate:142 cate:483 cate:351 cate:157 cate:373 cate:142 cate:629 cate:39 cate:708 cate:251 cate:339 cate:142 cate:262 cate:1 cate:113 cate:142 cate:462 position:1285 position:1258 position:1252 position:1206 position:1206 position:1206 position:1205 position:1194 position:1187 position:992 position:804 position:791 position:703 position:670 position:640 position:549 position:548 position:542 position:489 position:480 position:479 position:455 position:422 position:393 position:319 position:296 position:274 position:266 position:266 position:266 position:222 position:141 position:127 position:127 position:114 position:88 position:56 position:22 position:8 target:13418 target_cate:558 label:0
history:8719 history:11172 cate:311 cate:217 position:0 position:0 target:11707 target_cate:179 label:1
history:14968 history:8297 history:22914 history:5998 history:20253 history:41425 history:42664 history:46745 history:51179 history:33481 history:46814 history:55135 history:53124 history:61559 cate:463 cate:766 cate:714 cate:486 cate:628 cate:444 cate:281 cate:714 cate:142 cate:242 cate:174 cate:118 cate:714 cate:714 position:2006 position:1413 position:1323 position:1148 position:977 position:777 position:589 position:487 position:486 position:403 position:349 position:297 position:78 position:12 target:61908 target_cate:714 label:1
history:61119 cate:714 position:99 target:22907 target_cate:83 label:0
history:26172 cate:157 position:258 target:54529 target_cate:44 label:0
history:13830 history:10377 history:8193 history:16072 history:13543 history:18741 history:24205 history:18281 history:37272 history:27784 history:16658 history:27884 cate:384 cate:739 cate:558 cate:739 cate:135 cate:347 cate:558 cate:687 cate:498 cate:142 cate:197 cate:746 position:1447 position:1443 position:1380 position:1312 position:936 position:876 position:695 position:523 position:55 position:25 position:24 position:20 target:34463 target_cate:177 label:1
history:20842 history:11756 history:22110 history:30562 history:30697 cate:189 cate:68 cate:483 cate:776 cate:225 position:516 position:55 position:21 position:21 position:21 target:49113 target_cate:483 label:0
history:13646 history:46782 history:54138 cate:142 cate:798 cate:142 position:604 position:346 position:200 target:43698 target_cate:347 label:0
history:36434 cate:241 position:31 target:51537 target_cate:629 label:0
history:44121 history:35325 cate:397 cate:653 position:809 position:0 target:43399 target_cate:397 label:1
history:6438 history:11107 history:20073 history:25026 history:24434 history:35533 history:6318 history:25028 history:28352 history:32359 history:25734 history:26280 history:41466 history:25192 history:1909 history:11753 history:17770 history:24301 history:1728 history:9693 history:36444 history:40256 history:17961 history:36780 history:41093 history:8788 history:439 history:46397 history:46269 history:50462 history:40395 history:437 history:2582 history:4455 history:12361 history:14325 history:22294 history:26153 history:26607 history:29205 history:29878 history:33491 history:38795 history:41585 history:45480 history:51567 history:54245 history:19796 history:52446 cate:356 cate:194 cate:389 cate:89 cate:474 cate:330 cate:347 cate:384 cate:330 cate:90 cate:19 cate:385 cate:177 cate:68 cate:624 cate:68 cate:674 cate:463 cate:624 cate:194 cate:177 cate:389 cate:197 cate:642 cate:239 cate:111 cate:115 cate:113 cate:48 cate:251 cate:554 cate:115 cate:36 cate:163 cate:616 cate:524 cate:84 cate:190 cate:465 cate:398 cate:89 cate:166 cate:113 cate:330 cate:616 cate:449 cate:90 cate:140 cate:330 position:971 position:969 position:969 position:969 position:934 position:934 position:921 position:921 position:921 position:921 position:861 position:794 position:691 position:690 position:689 position:689 position:689 position:686 position:683 position:683 position:681 position:656 position:408 position:341 position:341 position:278 position:276 position:275 position:229 position:226 position:210 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:139 position:119 position:110 position:105 target:15142 target_cate:764 label:0
history:1573 cate:540 position:0 target:18294 target_cate:463 label:1
history:9837 history:13438 history:13690 cate:351 cate:629 cate:24 position:287 position:287 position:287 target:26044 target_cate:351 label:0
history:1708 history:2675 history:4935 history:7401 history:14413 history:22177 history:30319 history:32217 history:34342 history:40235 history:42963 history:43949 history:54816 cate:463 cate:115 cate:474 cate:616 cate:474 cate:44 cate:113 cate:279 cate:164 cate:142 cate:616 cate:649 cate:36 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 position:6 target:31992 target_cate:115 label:0
history:8025 history:11769 history:36188 history:42006 cate:142 cate:262 cate:714 cate:142 position:1107 position:1107 position:21 position:13 target:8209 target_cate:142 label:0
history:30266 cate:176 position:0 target:44167 target_cate:692 label:0
history:13000 history:14769 history:2940 history:27638 history:23158 cate:765 cate:27 cate:736 cate:554 cate:112 position:1155 position:797 position:348 position:348 position:334 target:55050 target_cate:725 label:0
history:32557 history:18668 history:43441 cate:765 cate:707 cate:396 position:1 position:0 position:0 target:44217 target_cate:681 label:1
history:5665 history:5964 history:18874 cate:542 cate:746 cate:196 position:1229 position:1202 position:123 target:16747 target_cate:179 label:0
history:7014 history:29912 history:42468 cate:194 cate:612 cate:558 position:2424 position:0 position:0 target:20800 target_cate:355 label:0
history:8320 history:9743 history:1735 history:442 history:5216 history:11568 cate:234 cate:251 cate:241 cate:603 cate:476 cate:649 position:211 position:70 position:61 position:34 position:34 position:27 target:32738 target_cate:153 label:0
history:533 history:1447 cate:744 cate:744 position:664 position:337 target:17843 target_cate:744 label:1
history:48390 history:48191 cate:714 cate:714 position:137 position:92 target:48864 target_cate:708 label:1
history:9312 history:16166 history:12754 history:21433 history:28142 history:7486 cate:215 cate:674 cate:241 cate:115 cate:558 cate:241 position:1910 position:1045 position:414 position:371 position:371 position:347 target:38629 target_cate:48 label:1
history:10401 history:11665 history:10739 cate:142 cate:364 cate:766 position:363 position:217 position:48 target:5989 target_cate:463 label:0
history:10408 history:14363 history:8807 history:14947 history:24701 history:44676 history:40914 history:12241 history:14906 history:29247 history:32347 history:5834 history:18291 history:18313 history:23375 history:24075 history:7020 history:14307 history:15891 cate:140 cate:140 cate:749 cate:281 cate:444 cate:388 cate:504 cate:385 cate:196 cate:339 cate:746 cate:351 cate:463 cate:746 cate:197 cate:90 cate:746 cate:576 cate:476 position:1338 position:1336 position:1305 position:1267 position:835 position:88 position:87 position:86 position:86 position:86 position:86 position:84 position:84 position:84 position:84 position:84 position:83 position:83 position:83 target:37949 target_cate:330 label:1
history:50194 cate:444 position:243 target:15572 target_cate:216 label:0
history:24021 cate:281 position:718 target:25850 target_cate:140 label:1
history:22185 history:28726 history:55777 cate:142 cate:766 cate:351 position:923 position:923 position:133 target:17 target_cate:541 label:1
history:31776 history:34767 history:28854 history:34769 history:38022 history:38667 history:32917 history:9094 history:40879 history:41634 history:42252 history:19865 history:47983 history:38818 history:40131 history:40690 history:18915 history:48539 history:49619 history:18554 history:24836 cate:70 cate:239 cate:113 cate:48 cate:486 cate:541 cate:352 cate:197 cate:347 cate:385 cate:34 cate:476 cate:704 cate:388 cate:385 cate:281 cate:225 cate:474 cate:157 cate:706 cate:53 position:490 position:490 position:473 position:360 position:360 position:360 position:209 position:199 position:199 position:199 position:199 position:198 position:198 position:196 position:196 position:174 position:93 position:36 position:36 position:0 position:0 target:25602 target_cate:707 label:1
history:10544 history:15159 history:23606 history:33556 history:46886 history:55061 history:2079 history:27022 history:40345 history:43556 history:3807 history:28732 cate:642 cate:87 cate:641 cate:113 cate:558 cate:157 cate:564 cate:44 cate:194 cate:26 cate:54 cate:113 position:844 position:362 position:362 position:362 position:362 position:362 position:205 position:205 position:205 position:205 position:0 position:0 target:51293 target_cate:272 label:0
history:19005 history:41469 history:42368 history:5739 history:30169 history:32266 history:54743 history:56959 history:26271 cate:145 cate:482 cate:707 cate:790 cate:101 cate:347 cate:197 cate:368 cate:674 position:365 position:365 position:365 position:258 position:258 position:258 position:258 position:258 position:0 target:5602 target_cate:158 label:0
history:7166 history:16886 history:21083 history:7328 history:25545 cate:560 cate:213 cate:87 cate:744 cate:87 position:474 position:474 position:474 position:214 position:214 target:32494 target_cate:321 label:1
history:2306 cate:260 position:51 target:30286 target_cate:179 label:0
history:57709 history:55115 cate:351 cate:483 position:99 position:50 target:25035 target_cate:142 label:0
history:16641 history:35845 cate:153 cate:311 position:0 position:0 target:36985 target_cate:68 label:1
history:31144 history:4107 cate:189 cate:168 position:1179 position:0 target:50619 target_cate:142 label:0
history:36331 history:9873 history:10659 history:14382 history:21430 history:28164 cate:680 cate:197 cate:185 cate:11 cate:115 cate:476 position:278 position:0 position:0 position:0 position:0 position:0 target:37887 target_cate:484 label:1
history:19519 history:3748 history:33772 history:22436 history:38789 history:46337 cate:649 cate:351 cate:210 cate:115 cate:113 cate:115 position:1038 position:517 position:470 position:349 position:150 position:37 target:23980 target_cate:649 label:1
history:30789 history:37586 history:42354 history:26171 history:15017 history:28654 history:44960 cate:142 cate:714 cate:142 cate:483 cate:484 cate:474 cate:157 position:158 position:158 position:146 position:36 position:26 position:26 position:26 target:41552 target_cate:746 label:1
history:52662 cate:576 position:0 target:53627 target_cate:776 label:0
history:12258 history:15133 history:15681 history:5066 history:6420 history:13421 history:6577 history:29202 history:38939 cate:216 cate:558 cate:111 cate:570 cate:447 cate:5 cate:111 cate:281 cate:347 position:1544 position:1359 position:1312 position:743 position:743 position:636 position:560 position:103 position:24 target:7818 target_cate:558 label:0
history:610 history:1258 history:2332 history:7508 history:10814 history:10797 history:11710 cate:543 cate:611 cate:611 cate:653 cate:110 cate:201 cate:179 position:2452 position:1361 position:935 position:669 position:524 position:55 position:45 target:11495 target_cate:558 label:1
history:12584 history:2707 history:1664 history:25878 history:25949 cate:790 cate:694 cate:694 cate:142 cate:611 position:768 position:729 position:625 position:236 position:7 target:25286 target_cate:792 label:1
history:32423 history:24223 cate:135 cate:90 position:421 position:76 target:2323 target_cate:399 label:0
history:11959 cate:197 position:0 target:15349 target_cate:351 label:1
history:44448 history:58138 history:41930 history:57603 history:59009 history:61316 history:61559 history:599 cate:339 cate:629 cate:115 cate:388 cate:1 cate:142 cate:714 cate:297 position:320 position:97 position:23 position:23 position:23 position:23 position:23 position:0 target:54434 target_cate:142 label:0
history:43441 history:12617 history:47970 history:52144 cate:396 cate:196 cate:142 cate:629 position:213 position:208 position:208 position:208 target:29211 target_cate:351 label:1
history:25327 history:40258 cate:656 cate:398 position:676 position:3 target:40261 target_cate:142 label:1
history:4637 cate:474 position:62 target:59864 target_cate:687 label:0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.item_emb_size = envs.get_global_env(
"hyper_parameters.item_emb_size", 64)
self.cat_emb_size = envs.get_global_env(
"hyper_parameters.cat_emb_size", 64)
self.position_emb_size = envs.get_global_env(
"hyper_parameters.position_emb_size", 64)
self.act = envs.get_global_env("hyper_parameters.act", "sigmoid")
self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
False)
# significant for speeding up the training process
self.use_DataLoader = envs.get_global_env(
"hyper_parameters.use_DataLoader", False)
self.item_count = envs.get_global_env("hyper_parameters.item_count",
63001)
self.cat_count = envs.get_global_env("hyper_parameters.cat_count", 801)
self.position_count = envs.get_global_env(
"hyper_parameters.position_count", 5001)
self.n_encoder_layers = envs.get_global_env(
"hyper_parameters.n_encoder_layers", 1)
self.d_model = envs.get_global_env("hyper_parameters.d_model", 96)
self.d_key = envs.get_global_env("hyper_parameters.d_key", None)
self.d_value = envs.get_global_env("hyper_parameters.d_value", None)
self.n_head = envs.get_global_env("hyper_parameters.n_head", None)
self.dropout_rate = envs.get_global_env(
"hyper_parameters.dropout_rate", 0.0)
self.postprocess_cmd = envs.get_global_env(
"hyper_parameters.postprocess_cmd", "da")
self.preprocess_cmd = envs.get_global_env(
"hyper_parameters.postprocess_cmd", "n")
self.prepostprocess_dropout = envs.get_global_env(
"hyper_parameters.prepostprocess_dropout", 0.0)
self.d_inner_hid = envs.get_global_env("hyper_parameters.d_inner_hid",
512)
self.relu_dropout = envs.get_global_env(
"hyper_parameters.relu_dropout", 0.0)
self.layer_sizes = envs.get_global_env("hyper_parameters.fc_sizes",
None)
def multi_head_attention(self, queries, keys, values, d_key, d_value,
d_model, n_head, dropout_rate):
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3
):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = fluid.layers.fc(input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
k = fluid.layers.fc(input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = fluid.layers.fc(input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Reshape input tensors at the last dimension to split multi-heads
and then transpose. Specifically, transform the input tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] to the output tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped_q = fluid.layers.reshape(
x=queries, shape=[0, 0, n_head, d_key], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
q = fluid.layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
reshaped_k = fluid.layers.reshape(
x=keys, shape=[0, 0, n_head, d_key], inplace=True)
k = fluid.layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = fluid.layers.reshape(
x=values, shape=[0, 0, n_head, d_value], inplace=True)
v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
return q, k, v
def scaled_dot_product_attention(q, k, v, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
product = fluid.layers.matmul(
x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
weights = fluid.layers.softmax(product)
if dropout_rate:
weights = fluid.layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=None,
is_test=False)
out = fluid.layers.matmul(weights, v)
return out
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = fluid.layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return fluid.layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)
ctx_multiheads = scaled_dot_product_attention(q, k, v, d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
proj_out = fluid.layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
return proj_out
def encoder_layer(self, x):
attention_out = self.multi_head_attention(
pre_process_layer(x, self.preprocess_cmd,
self.prepostprocess_dropout), None, None,
self.d_key, self.d_value, self.d_model, self.n_head,
self.dropout_rate)
attn_output = post_process_layer(x, attention_out,
self.postprocess_cmd,
self.prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, self.preprocess_cmd,
self.prepostprocess_dropout), self.d_inner_hid,
self.d_model, self.relu_dropout)
return post_process_layer(attn_output, ffd_output,
self.postprocess_cmd,
self.prepostprocess_dropout)
def net(self, inputs, is_infer=False):
init_value_ = 0.1
hist_item_seq = self._sparse_data_var[1]
hist_cat_seq = self._sparse_data_var[2]
position_seq = self._sparse_data_var[3]
target_item = self._sparse_data_var[4]
target_cat = self._sparse_data_var[5]
target_position = self._sparse_data_var[6]
self.label = self._sparse_data_var[0]
item_emb_attr = fluid.ParamAttr(name="item_emb")
cat_emb_attr = fluid.ParamAttr(name="cat_emb")
position_emb_attr = fluid.ParamAttr(name="position_emb")
hist_item_emb = fluid.embedding(
input=hist_item_seq,
size=[self.item_count, self.item_emb_size],
param_attr=item_emb_attr,
is_sparse=self.is_sparse)
hist_cat_emb = fluid.embedding(
input=hist_cat_seq,
size=[self.cat_count, self.cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=self.is_sparse)
hist_position_emb = fluid.embedding(
input=hist_cat_seq,
size=[self.position_count, self.position_emb_size],
param_attr=position_emb_attr,
is_sparse=self.is_sparse)
target_item_emb = fluid.embedding(
input=target_item,
size=[self.item_count, self.item_emb_size],
param_attr=item_emb_attr,
is_sparse=self.is_sparse)
target_cat_emb = fluid.embedding(
input=target_cat,
size=[self.cat_count, self.cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=self.is_sparse)
target_position_emb = fluid.embedding(
input=target_position,
size=[self.position_count, self.position_emb_size],
param_attr=position_emb_attr,
is_sparse=self.is_sparse)
item_sequence_target = fluid.layers.reduce_sum(
fluid.layers.sequence_concat([hist_item_emb, target_item_emb]),
dim=1)
cat_sequence_target = fluid.layers.reduce_sum(
fluid.layers.sequence_concat([hist_cat_emb, target_cat_emb]),
dim=1)
position_sequence_target = fluid.layers.reduce_sum(
fluid.layers.sequence_concat(
[hist_position_emb, target_position_emb]),
dim=1)
whole_embedding_withlod = fluid.layers.concat(
[
item_sequence_target, cat_sequence_target,
position_sequence_target
],
axis=1)
pad_value = fluid.layers.assign(input=np.array(
[0.0], dtype=np.float32))
whole_embedding, _ = fluid.layers.sequence_pad(whole_embedding_withlod,
pad_value)
for _ in range(self.n_encoder_layers):
enc_output = self.encoder_layer(whole_embedding)
enc_input = enc_output
enc_output = pre_process_layer(enc_output, self.preprocess_cmd,
self.prepostprocess_dropout)
dnn_input = fluid.layers.reduce_sum(enc_output, dim=1)
for s in self.layer_sizes:
dnn_input = fluid.layers.fc(
input=dnn_input,
size=s,
act=self.act,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=init_value_ / math.sqrt(float(10)))),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=init_value_)))
y_dnn = fluid.layers.fc(input=dnn_input, size=1, act=None)
self.predict = fluid.layers.sigmoid(y_dnn)
cost = fluid.layers.log_loss(
input=self.predict, label=fluid.layers.cast(self.label, "float32"))
avg_cost = fluid.layers.reduce_sum(cost)
self._cost = avg_cost
predict_2d = fluid.layers.concat([1 - self.predict, self.predict], 1)
label_int = fluid.layers.cast(self.label, 'int64')
auc_var, batch_auc_var, _ = fluid.layers.auc(input=predict_2d,
label=label_int,
slide_steps=0)
self._metrics["AUC"] = auc_var
self._metrics["BATCH_AUC"] = batch_auc_var
if is_infer:
self._infer_results["AUC"] = auc_var
......@@ -37,6 +37,7 @@
| xDeepFM | xDeepFM | [xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://dl.acm.org/doi/pdf/10.1145/3219819.3220023)(2018) |
| DIN | Deep Interest Network | [Deep Interest Network for Click-Through Rate Prediction](https://dl.acm.org/doi/pdf/10.1145/3219819.3219823)(2018) |
| DIEN | Deep Interest Evolution Network | [Deep Interest Evolution Network for Click-Through Rate Prediction](https://www.aaai.org/ojs/index.php/AAAI/article/view/4545/4423)(2019) |
| BST | transformer in user behavior sequence for rank | [Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/pdf/1905.06874v1.pdf)(2019) |
| FGCNN | Feature Generation by CNN | [Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1904.04447.pdf)(2019) |
| FIBINET | Combining Feature Importance and Bilinear feature Interaction | [《FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction》]( https://arxiv.org/pdf/1905.09433.pdf)(2019) |
| FLEN | Leveraging Field for Scalable CTR Prediction | [《FLEN: Leveraging Field for Scalable CTR Prediction》]( https://arxiv.org/pdf/1911.04690.pdf)(2019) |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册