提交 c718dc5d 编写于 作者: M malin10

update gru4rec

上级 e9543dc8
...@@ -17,17 +17,18 @@ workspace: "paddlerec.models.recall.gru4rec" ...@@ -17,17 +17,18 @@ workspace: "paddlerec.models.recall.gru4rec"
dataset: dataset:
- name: dataset_train - name: dataset_train
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader # QueueDataset
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/rsc15_reader.py" data_converter: "{workspace}/rsc15_reader.py"
- name: dataset_infer - name: dataset_infer
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader #QueueDataset
data_path: "{workspace}/data/test" data_path: "{workspace}/data/test"
data_converter: "{workspace}/rsc15_reader.py" data_converter: "{workspace}/rsc15_reader.py"
hyper_parameters: hyper_parameters:
vocab_size: 1000 recall_k: 20
vocab_size: 37483
hid_size: 100 hid_size: 100
emb_lr_x: 10.0 emb_lr_x: 10.0
gru_lr_x: 1.0 gru_lr_x: 1.0
...@@ -47,15 +48,15 @@ runner: ...@@ -47,15 +48,15 @@ runner:
- name: train_runner - name: train_runner
class: train class: train
device: cpu device: cpu
epochs: 3 epochs: 10
save_checkpoint_interval: 2 save_checkpoint_interval: 2
save_inference_interval: 4 save_inference_interval: 4
save_checkpoint_path: "increment" save_checkpoint_path: "increment_gru4rec"
save_inference_path: "inference" save_inference_path: "inference_gru4rec"
print_interval: 10 print_interval: 10
- name: infer_runner - name: infer_runner
class: infer class: infer
init_model_path: "increment/0" init_model_path: "increment_gru4rec"
device: cpu device: cpu
phase: phase:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import codecs
def convert_format(input, output):
with codecs.open(input, "r", encoding='utf-8') as rf:
with codecs.open(output, "w", encoding='utf-8') as wf:
last_sess = -1
sign = 1
i = 0
for l in rf:
i = i + 1
if i == 1:
continue
if (i % 1000000 == 1):
print(i)
tokens = l.strip().split()
if (int(tokens[0]) != last_sess):
if (sign):
sign = 0
wf.write(tokens[1] + " ")
else:
wf.write("\n" + tokens[1] + " ")
last_sess = int(tokens[0])
else:
wf.write(tokens[1] + " ")
input = "rsc15_train_tr.txt"
output = "rsc15_train_tr_paddle.txt"
input2 = "rsc15_test.txt"
output2 = "rsc15_test_paddle.txt"
convert_format(input, output)
convert_format(input2, output2)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import sys
import time
import os
lasttime = time.time()
FLUSH_INTERVAL = 0.1
def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
_download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
"./yoochoose-clicks.dat", True)
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:20:12 2015
@author: Balázs Hidasi
"""
import numpy as np
import pandas as pd
import datetime as dt
import time
PATH_TO_ORIGINAL_DATA = './'
PATH_TO_PROCESSED_DATA = './'
data = pd.read_csv(
PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat',
sep=',',
header=0,
usecols=[0, 1, 2],
dtype={0: np.int32,
1: str,
2: np.int64})
data.columns = ['SessionId', 'TimeStr', 'ItemId']
data['Time'] = data.TimeStr.apply(lambda x: time.mktime(dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timetuple())) #This is not UTC. It does not really matter.
del (data['TimeStr'])
session_lengths = data.groupby('SessionId').size()
data = data[np.in1d(data.SessionId, session_lengths[session_lengths > 1]
.index)]
item_supports = data.groupby('ItemId').size()
data = data[np.in1d(data.ItemId, item_supports[item_supports >= 5].index)]
session_lengths = data.groupby('SessionId').size()
data = data[np.in1d(data.SessionId, session_lengths[session_lengths >= 2]
.index)]
tmax = data.Time.max()
session_max_times = data.groupby('SessionId').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_test = session_max_times[session_max_times >= tmax - 86400].index
train = data[np.in1d(data.SessionId, session_train)]
test = data[np.in1d(data.SessionId, session_test)]
test = test[np.in1d(test.ItemId, train.ItemId)]
tslength = test.groupby('SessionId').size()
test = test[np.in1d(test.SessionId, tslength[tslength >= 2].index)]
print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(train), train.SessionId.nunique(), train.ItemId.nunique()))
train.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False)
print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(test), test.SessionId.nunique(), test.ItemId.nunique()))
test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False)
tmax = train.Time.max()
session_max_times = train.groupby('SessionId').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_valid = session_max_times[session_max_times >= tmax - 86400].index
train_tr = train[np.in1d(train.SessionId, session_train)]
valid = train[np.in1d(train.SessionId, session_valid)]
valid = valid[np.in1d(valid.ItemId, train_tr.ItemId)]
tslength = valid.groupby('SessionId').size()
valid = valid[np.in1d(valid.SessionId, tslength[tslength >= 2].index)]
print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(train_tr), train_tr.SessionId.nunique(), train_tr.ItemId.nunique()))
train_tr.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False)
print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(valid), valid.SessionId.nunique(), valid.ItemId.nunique()))
valid.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False)
#! /bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
dataset=$1
src=$1
if [[ $src == "yoochoose1_4" || $src == "yoochoose1_64" ]];then
src="yoochoose"
elif [[ $src == "diginetica" ]];then
src="diginetica"
else
echo "Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]"
exit 1
fi
echo "begin to download data"
cd data && python download.py $src
mkdir $dataset
python preprocess.py --dataset $src
echo "begin to convert data (binary -> txt)"
python convert_data.py --data_dir $dataset
cat ${dataset}/train.txt | wc -l >> config.txt
rm -rf train && mkdir train
mv ${dataset}/train.txt train
rm -rf test && mkdir test
mv ${dataset}/test.txt test
...@@ -16,6 +16,7 @@ import paddle.fluid as fluid ...@@ -16,6 +16,7 @@ import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase from paddlerec.core.model import ModelBase
from paddlerec.core.metrics import Precision
class Model(ModelBase): class Model(ModelBase):
...@@ -81,13 +82,15 @@ class Model(ModelBase): ...@@ -81,13 +82,15 @@ class Model(ModelBase):
high=self.init_high_bound), high=self.init_high_bound),
learning_rate=self.fc_lr_x)) learning_rate=self.fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq)
acc = fluid.layers.accuracy( # acc = fluid.layers.accuracy(
input=fc, label=dst_wordseq, k=self.recall_k) # input=fc, label=dst_wordseq, k=self.recall_k)
acc = Precision(input=fc, label=dst_wordseq, k=self.recall_k)
if is_infer: if is_infer:
self._infer_results['recall20'] = acc self._infer_results['P@20'] = acc
return return
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
self._cost = avg_cost self._cost = avg_cost
self._metrics["cost"] = avg_cost self._metrics["cost"] = avg_cost
self._metrics["acc"] = acc self._metrics["P@20"] = acc
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册