Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c718dc5d
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c718dc5d
编写于
7月 21, 2020
作者:
M
malin10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update gru4rec
上级
e9543dc8
变更
6
显示空白变更内容
内联
并排
Showing 6 changed files with 239 additions and 11 deletions
+239
-11
models/recall/gru4rec/config.yaml
models/recall/gru4rec/config.yaml
+8
-7
models/recall/gru4rec/data/convert_format.py
models/recall/gru4rec/data/convert_format.py
+48
-0
models/recall/gru4rec/data/download.py
models/recall/gru4rec/data/download.py
+61
-0
models/recall/gru4rec/data/preprocess.py
models/recall/gru4rec/data/preprocess.py
+70
-0
models/recall/gru4rec/data_prepare.sh
models/recall/gru4rec/data_prepare.sh
+45
-0
models/recall/gru4rec/model.py
models/recall/gru4rec/model.py
+7
-4
未找到文件。
models/recall/gru4rec/config.yaml
浏览文件 @
c718dc5d
...
...
@@ -17,17 +17,18 @@ workspace: "paddlerec.models.recall.gru4rec"
dataset
:
-
name
:
dataset_train
batch_size
:
5
type
:
QueueDataset
type
:
DataLoader
#
QueueDataset
data_path
:
"
{workspace}/data/train"
data_converter
:
"
{workspace}/rsc15_reader.py"
-
name
:
dataset_infer
batch_size
:
5
type
:
QueueDataset
type
:
DataLoader
#
QueueDataset
data_path
:
"
{workspace}/data/test"
data_converter
:
"
{workspace}/rsc15_reader.py"
hyper_parameters
:
vocab_size
:
1000
recall_k
:
20
vocab_size
:
37483
hid_size
:
100
emb_lr_x
:
10.0
gru_lr_x
:
1.0
...
...
@@ -47,15 +48,15 @@ runner:
-
name
:
train_runner
class
:
train
device
:
cpu
epochs
:
3
epochs
:
10
save_checkpoint_interval
:
2
save_inference_interval
:
4
save_checkpoint_path
:
"
increment"
save_inference_path
:
"
inference"
save_checkpoint_path
:
"
increment
_gru4rec
"
save_inference_path
:
"
inference
_gru4rec
"
print_interval
:
10
-
name
:
infer_runner
class
:
infer
init_model_path
:
"
increment
/0
"
init_model_path
:
"
increment
_gru4rec
"
device
:
cpu
phase
:
...
...
models/recall/gru4rec/data/convert_format.py
0 → 100644
浏览文件 @
c718dc5d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
codecs
def
convert_format
(
input
,
output
):
with
codecs
.
open
(
input
,
"r"
,
encoding
=
'utf-8'
)
as
rf
:
with
codecs
.
open
(
output
,
"w"
,
encoding
=
'utf-8'
)
as
wf
:
last_sess
=
-
1
sign
=
1
i
=
0
for
l
in
rf
:
i
=
i
+
1
if
i
==
1
:
continue
if
(
i
%
1000000
==
1
):
print
(
i
)
tokens
=
l
.
strip
().
split
()
if
(
int
(
tokens
[
0
])
!=
last_sess
):
if
(
sign
):
sign
=
0
wf
.
write
(
tokens
[
1
]
+
" "
)
else
:
wf
.
write
(
"
\n
"
+
tokens
[
1
]
+
" "
)
last_sess
=
int
(
tokens
[
0
])
else
:
wf
.
write
(
tokens
[
1
]
+
" "
)
input
=
"rsc15_train_tr.txt"
output
=
"rsc15_train_tr_paddle.txt"
input2
=
"rsc15_test.txt"
output2
=
"rsc15_test_paddle.txt"
convert_format
(
input
,
output
)
convert_format
(
input2
,
output2
)
models/recall/gru4rec/data/download.py
0 → 100644
浏览文件 @
c718dc5d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
import time

import requests
lasttime
=
time
.
time
()
FLUSH_INTERVAL
=
0.1
def
progress
(
str
,
end
=
False
):
global
lasttime
if
end
:
str
+=
"
\n
"
lasttime
=
0
if
time
.
time
()
-
lasttime
>=
FLUSH_INTERVAL
:
sys
.
stdout
.
write
(
"
\r
%s"
%
str
)
lasttime
=
time
.
time
()
sys
.
stdout
.
flush
()
def
_download_file
(
url
,
savepath
,
print_progress
):
r
=
requests
.
get
(
url
,
stream
=
True
)
total_length
=
r
.
headers
.
get
(
'content-length'
)
if
total_length
is
None
:
with
open
(
savepath
,
'wb'
)
as
f
:
shutil
.
copyfileobj
(
r
.
raw
,
f
)
else
:
with
open
(
savepath
,
'wb'
)
as
f
:
dl
=
0
total_length
=
int
(
total_length
)
starttime
=
time
.
time
()
if
print_progress
:
print
(
"Downloading %s"
%
os
.
path
.
basename
(
savepath
))
for
data
in
r
.
iter_content
(
chunk_size
=
4096
):
dl
+=
len
(
data
)
f
.
write
(
data
)
if
print_progress
:
done
=
int
(
50
*
dl
/
total_length
)
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
done
,
float
(
100
*
dl
)
/
total_length
))
if
print_progress
:
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
50
,
100
),
end
=
True
)
_download_file
(
"https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat"
,
"./yoochoose-clicks.dat"
,
True
)
models/recall/gru4rec/data/preprocess.py
0 → 100644
浏览文件 @
c718dc5d
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:20:12 2015
@author: Balázs Hidasi
"""
import
numpy
as
np
import
pandas
as
pd
import
datetime
as
dt
import
time
PATH_TO_ORIGINAL_DATA
=
'./'
PATH_TO_PROCESSED_DATA
=
'./'
data
=
pd
.
read_csv
(
PATH_TO_ORIGINAL_DATA
+
'yoochoose-clicks.dat'
,
sep
=
','
,
header
=
0
,
usecols
=
[
0
,
1
,
2
],
dtype
=
{
0
:
np
.
int32
,
1
:
str
,
2
:
np
.
int64
})
data
.
columns
=
[
'SessionId'
,
'TimeStr'
,
'ItemId'
]
data
[
'Time'
]
=
data
.
TimeStr
.
apply
(
lambda
x
:
time
.
mktime
(
dt
.
datetime
.
strptime
(
x
,
'%Y-%m-%dT%H:%M:%S.%fZ'
).
timetuple
()))
#This is not UTC. It does not really matter.
del
(
data
[
'TimeStr'
])
session_lengths
=
data
.
groupby
(
'SessionId'
).
size
()
data
=
data
[
np
.
in1d
(
data
.
SessionId
,
session_lengths
[
session_lengths
>
1
]
.
index
)]
item_supports
=
data
.
groupby
(
'ItemId'
).
size
()
data
=
data
[
np
.
in1d
(
data
.
ItemId
,
item_supports
[
item_supports
>=
5
].
index
)]
session_lengths
=
data
.
groupby
(
'SessionId'
).
size
()
data
=
data
[
np
.
in1d
(
data
.
SessionId
,
session_lengths
[
session_lengths
>=
2
]
.
index
)]
tmax
=
data
.
Time
.
max
()
session_max_times
=
data
.
groupby
(
'SessionId'
).
Time
.
max
()
session_train
=
session_max_times
[
session_max_times
<
tmax
-
86400
].
index
session_test
=
session_max_times
[
session_max_times
>=
tmax
-
86400
].
index
train
=
data
[
np
.
in1d
(
data
.
SessionId
,
session_train
)]
test
=
data
[
np
.
in1d
(
data
.
SessionId
,
session_test
)]
test
=
test
[
np
.
in1d
(
test
.
ItemId
,
train
.
ItemId
)]
tslength
=
test
.
groupby
(
'SessionId'
).
size
()
test
=
test
[
np
.
in1d
(
test
.
SessionId
,
tslength
[
tslength
>=
2
].
index
)]
print
(
'Full train set
\n\t
Events: {}
\n\t
Sessions: {}
\n\t
Items: {}'
.
format
(
len
(
train
),
train
.
SessionId
.
nunique
(),
train
.
ItemId
.
nunique
()))
train
.
to_csv
(
PATH_TO_PROCESSED_DATA
+
'rsc15_train_full.txt'
,
sep
=
'
\t
'
,
index
=
False
)
print
(
'Test set
\n\t
Events: {}
\n\t
Sessions: {}
\n\t
Items: {}'
.
format
(
len
(
test
),
test
.
SessionId
.
nunique
(),
test
.
ItemId
.
nunique
()))
test
.
to_csv
(
PATH_TO_PROCESSED_DATA
+
'rsc15_test.txt'
,
sep
=
'
\t
'
,
index
=
False
)
tmax
=
train
.
Time
.
max
()
session_max_times
=
train
.
groupby
(
'SessionId'
).
Time
.
max
()
session_train
=
session_max_times
[
session_max_times
<
tmax
-
86400
].
index
session_valid
=
session_max_times
[
session_max_times
>=
tmax
-
86400
].
index
train_tr
=
train
[
np
.
in1d
(
train
.
SessionId
,
session_train
)]
valid
=
train
[
np
.
in1d
(
train
.
SessionId
,
session_valid
)]
valid
=
valid
[
np
.
in1d
(
valid
.
ItemId
,
train_tr
.
ItemId
)]
tslength
=
valid
.
groupby
(
'SessionId'
).
size
()
valid
=
valid
[
np
.
in1d
(
valid
.
SessionId
,
tslength
[
tslength
>=
2
].
index
)]
print
(
'Train set
\n\t
Events: {}
\n\t
Sessions: {}
\n\t
Items: {}'
.
format
(
len
(
train_tr
),
train_tr
.
SessionId
.
nunique
(),
train_tr
.
ItemId
.
nunique
()))
train_tr
.
to_csv
(
PATH_TO_PROCESSED_DATA
+
'rsc15_train_tr.txt'
,
sep
=
'
\t
'
,
index
=
False
)
print
(
'Validation set
\n\t
Events: {}
\n\t
Sessions: {}
\n\t
Items: {}'
.
format
(
len
(
valid
),
valid
.
SessionId
.
nunique
(),
valid
.
ItemId
.
nunique
()))
valid
.
to_csv
(
PATH_TO_PROCESSED_DATA
+
'rsc15_train_valid.txt'
,
sep
=
'
\t
'
,
index
=
False
)
models/recall/gru4rec/data_prepare.sh
0 → 100644
浏览文件 @
c718dc5d
#! /bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set
-e
dataset
=
$1
src
=
$1
if
[[
$src
==
"yoochoose1_4"
||
$src
==
"yoochoose1_64"
]]
;
then
src
=
"yoochoose"
elif
[[
$src
==
"diginetica"
]]
;
then
src
=
"diginetica"
else
echo
"Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]"
exit
1
fi
echo
"begin to download data"
cd
data
&&
python download.py
$src
mkdir
$dataset
python preprocess.py
--dataset
$src
echo
"begin to convert data (binary -> txt)"
python convert_data.py
--data_dir
$dataset
cat
${
dataset
}
/train.txt |
wc
-l
>>
config.txt
rm
-rf
train
&&
mkdir
train
mv
${
dataset
}
/train.txt train
rm
-rf
test
&&
mkdir test
mv
${
dataset
}
/test.txt
test
models/recall/gru4rec/model.py
浏览文件 @
c718dc5d
...
...
@@ -16,6 +16,7 @@ import paddle.fluid as fluid
from
paddlerec.core.utils
import
envs
from
paddlerec.core.model
import
ModelBase
from
paddlerec.core.metrics
import
Precision
class
Model
(
ModelBase
):
...
...
@@ -81,13 +82,15 @@ class Model(ModelBase):
high
=
self
.
init_high_bound
),
learning_rate
=
self
.
fc_lr_x
))
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
fc
,
label
=
dst_wordseq
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
fc
,
label
=
dst_wordseq
,
k
=
self
.
recall_k
)
# acc = fluid.layers.accuracy(
# input=fc, label=dst_wordseq, k=self.recall_k)
acc
=
Precision
(
input
=
fc
,
label
=
dst_wordseq
,
k
=
self
.
recall_k
)
if
is_infer
:
self
.
_infer_results
[
'
recall
20'
]
=
acc
self
.
_infer_results
[
'
P@
20'
]
=
acc
return
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
self
.
_cost
=
avg_cost
self
.
_metrics
[
"cost"
]
=
avg_cost
self
.
_metrics
[
"
acc
"
]
=
acc
self
.
_metrics
[
"
P@20
"
]
=
acc
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录