From c718dc5d4b77265700ccb3a209459d5d1f659c79 Mon Sep 17 00:00:00 2001 From: malin10 Date: Tue, 21 Jul 2020 21:40:05 +0800 Subject: [PATCH] update gru4rec --- models/recall/gru4rec/config.yaml | 15 +++-- models/recall/gru4rec/data/convert_format.py | 48 ++++++++++++++ models/recall/gru4rec/data/download.py | 61 +++++++++++++++++ models/recall/gru4rec/data/preprocess.py | 70 ++++++++++++++++++++ models/recall/gru4rec/data_prepare.sh | 45 +++++++++++++ models/recall/gru4rec/model.py | 11 +-- 6 files changed, 239 insertions(+), 11 deletions(-) create mode 100644 models/recall/gru4rec/data/convert_format.py create mode 100644 models/recall/gru4rec/data/download.py create mode 100644 models/recall/gru4rec/data/preprocess.py create mode 100644 models/recall/gru4rec/data_prepare.sh diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 98250ae0..b0f8073e 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -17,17 +17,18 @@ workspace: "paddlerec.models.recall.gru4rec" dataset: - name: dataset_train batch_size: 5 - type: QueueDataset + type: DataLoader # QueueDataset data_path: "{workspace}/data/train" data_converter: "{workspace}/rsc15_reader.py" - name: dataset_infer batch_size: 5 - type: QueueDataset + type: DataLoader #QueueDataset data_path: "{workspace}/data/test" data_converter: "{workspace}/rsc15_reader.py" hyper_parameters: - vocab_size: 1000 + recall_k: 20 + vocab_size: 37483 hid_size: 100 emb_lr_x: 10.0 gru_lr_x: 1.0 @@ -47,15 +48,15 @@ runner: - name: train_runner class: train device: cpu - epochs: 3 + epochs: 10 save_checkpoint_interval: 2 save_inference_interval: 4 - save_checkpoint_path: "increment" - save_inference_path: "inference" + save_checkpoint_path: "increment_gru4rec" + save_inference_path: "inference_gru4rec" print_interval: 10 - name: infer_runner class: infer - init_model_path: "increment/0" + init_model_path: "increment_gru4rec" device: cpu phase: diff --git a/models/recall/gru4rec/data/convert_format.py b/models/recall/gru4rec/data/convert_format.py new file mode 100644 index 00000000..9a6867e6 --- /dev/null +++ b/models/recall/gru4rec/data/convert_format.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import codecs + + +def convert_format(input, output): + with codecs.open(input, "r", encoding='utf-8') as rf: + with codecs.open(output, "w", encoding='utf-8') as wf: + last_sess = -1 + sign = 1 + i = 0 + for l in rf: + i = i + 1 + if i == 1: + continue + if (i % 1000000 == 1): + print(i) + tokens = l.strip().split() + if (int(tokens[0]) != last_sess): + if (sign): + sign = 0 + wf.write(tokens[1] + " ") + else: + wf.write("\n" + tokens[1] + " ") + last_sess = int(tokens[0]) + else: + wf.write(tokens[1] + " ") + + +input = "rsc15_train_tr.txt" +output = "rsc15_train_tr_paddle.txt" +input2 = "rsc15_test.txt" +output2 = "rsc15_test_paddle.txt" +convert_format(input, output) +convert_format(input2, output2) diff --git a/models/recall/gru4rec/data/download.py b/models/recall/gru4rec/data/download.py new file mode 100644 index 00000000..b0e0979e --- /dev/null +++ b/models/recall/gru4rec/data/download.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +import sys +import time +import os + +lasttime = time.time() +FLUSH_INTERVAL = 0.1 + + +def progress(str, end=False): + global lasttime + if end: + str += "\n" + lasttime = 0 + if time.time() - lasttime >= FLUSH_INTERVAL: + sys.stdout.write("\r%s" % str) + lasttime = time.time() + sys.stdout.flush() + + +def _download_file(url, savepath, print_progress): + r = requests.get(url, stream=True) + total_length = r.headers.get('content-length') + + if total_length is None: + with open(savepath, 'wb') as f: + shutil.copyfileobj(r.raw, f) + else: + with open(savepath, 'wb') as f: + dl = 0 + total_length = int(total_length) + starttime = time.time() + if print_progress: + print("Downloading %s" % os.path.basename(savepath)) + for data in r.iter_content(chunk_size=4096): + dl += len(data) + f.write(data) + if print_progress: + done = int(50 * dl / total_length) + progress("[%-50s] %.2f%%" % + ('=' * done, float(100 * dl) / total_length)) + if print_progress: + progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) + + +_download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat", + "./yoochoose-clicks.dat", True) diff --git a/models/recall/gru4rec/data/preprocess.py b/models/recall/gru4rec/data/preprocess.py new file mode 100644 index 00000000..66ed72b6 --- /dev/null +++ b/models/recall/gru4rec/data/preprocess.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Jun 25 16:20:12 2015 + +@author: Balázs Hidasi +""" + +import numpy as np +import pandas as pd +import datetime as dt +import time + +PATH_TO_ORIGINAL_DATA = './' +PATH_TO_PROCESSED_DATA = './' + +data = pd.read_csv( + PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat', + sep=',', + header=0, + usecols=[0, 1, 2], + dtype={0: np.int32, + 1: str, + 2: np.int64}) +data.columns = ['SessionId', 'TimeStr', 'ItemId'] +data['Time'] = data.TimeStr.apply(lambda x: time.mktime(dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timetuple())) #This is not UTC. It does not really matter. +del (data['TimeStr']) + +session_lengths = data.groupby('SessionId').size() +data = data[np.in1d(data.SessionId, session_lengths[session_lengths > 1] + .index)] +item_supports = data.groupby('ItemId').size() +data = data[np.in1d(data.ItemId, item_supports[item_supports >= 5].index)] +session_lengths = data.groupby('SessionId').size() +data = data[np.in1d(data.SessionId, session_lengths[session_lengths >= 2] + .index)] + +tmax = data.Time.max() +session_max_times = data.groupby('SessionId').Time.max() +session_train = session_max_times[session_max_times < tmax - 86400].index +session_test = session_max_times[session_max_times >= tmax - 86400].index +train = data[np.in1d(data.SessionId, session_train)] +test = data[np.in1d(data.SessionId, session_test)] +test = test[np.in1d(test.ItemId, train.ItemId)] +tslength = test.groupby('SessionId').size() +test = test[np.in1d(test.SessionId, tslength[tslength >= 2].index)] +print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(train), train.SessionId.nunique(), train.ItemId.nunique())) +train.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False) +print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(test), test.SessionId.nunique(), test.ItemId.nunique())) +test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False) + +tmax = train.Time.max() +session_max_times = train.groupby('SessionId').Time.max() +session_train = session_max_times[session_max_times < tmax - 86400].index +session_valid = session_max_times[session_max_times >= tmax - 86400].index +train_tr = train[np.in1d(train.SessionId, session_train)] +valid = train[np.in1d(train.SessionId, session_valid)] +valid = valid[np.in1d(valid.ItemId, train_tr.ItemId)] +tslength = valid.groupby('SessionId').size() +valid = valid[np.in1d(valid.SessionId, tslength[tslength >= 2].index)] +print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(train_tr), train_tr.SessionId.nunique(), train_tr.ItemId.nunique())) +train_tr.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False) +print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(valid), valid.SessionId.nunique(), valid.ItemId.nunique())) +valid.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False) diff --git a/models/recall/gru4rec/data_prepare.sh b/models/recall/gru4rec/data_prepare.sh new file mode 100644 index 00000000..a97e57ab --- /dev/null +++ b/models/recall/gru4rec/data_prepare.sh @@ -0,0 +1,45 @@ +#! /bin/bash + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +dataset=$1 +src=$1 + +if [[ $src == "yoochoose1_4" || $src == "yoochoose1_64" ]];then + src="yoochoose" +elif [[ $src == "diginetica" ]];then + src="diginetica" +else + echo "Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]" + exit 1 +fi + +echo "begin to download data" +cd data && python download.py $src +mkdir $dataset +python preprocess.py --dataset $src + +echo "begin to convert data (binary -> txt)" +python convert_data.py --data_dir $dataset + +cat ${dataset}/train.txt | wc -l >> config.txt + +rm -rf train && mkdir train +mv ${dataset}/train.txt train + +rm -rf test && mkdir test +mv ${dataset}/test.txt test diff --git a/models/recall/gru4rec/model.py b/models/recall/gru4rec/model.py index be12ad0f..27e47f1c 100644 --- a/models/recall/gru4rec/model.py +++ b/models/recall/gru4rec/model.py @@ -16,6 +16,7 @@ import paddle.fluid as fluid from paddlerec.core.utils import envs from paddlerec.core.model import ModelBase +from paddlerec.core.metrics import Precision class Model(ModelBase): @@ -81,13 +82,15 @@ class Model(ModelBase): high=self.init_high_bound), learning_rate=self.fc_lr_x)) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) - acc = fluid.layers.accuracy( - input=fc, label=dst_wordseq, k=self.recall_k) + # acc = fluid.layers.accuracy( + # input=fc, label=dst_wordseq, k=self.recall_k) + acc = Precision(input=fc, label=dst_wordseq, k=self.recall_k) + if is_infer: - self._infer_results['recall20'] = acc + self._infer_results['P@20'] = acc return avg_cost = fluid.layers.mean(x=cost) self._cost = avg_cost self._metrics["cost"] = avg_cost - self._metrics["acc"] = acc + self._metrics["P@20"] = acc -- GitLab