Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
6fa3457a
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
6fa3457a
编写于
5月 14, 2020
作者:
Z
zhangwenhui03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add gru4rec infer
上级
3a3a2356
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
67 addition
and
4 deletion
+67
-4
core/trainers/transpiler_trainer.py
core/trainers/transpiler_trainer.py
+5
-1
models/recall/gru4rec/config.yaml
models/recall/gru4rec/config.yaml
+9
-1
models/recall/gru4rec/model.py
models/recall/gru4rec/model.py
+11
-2
models/recall/gru4rec/rsc15_infer_reader.py
models/recall/gru4rec/rsc15_infer_reader.py
+42
-0
未找到文件。
core/trainers/transpiler_trainer.py
浏览文件 @
6fa3457a
...
...
@@ -247,6 +247,9 @@ class TranspileTrainer(Trainer):
model_list
=
[(
0
,
envs
.
get_global_env
(
'evaluate_model_path'
,
""
,
namespace
=
'evaluate'
))]
is_return_numpy
=
envs
.
get_global_env
(
'is_return_numpy'
,
True
,
namespace
=
'evaluate'
)
for
(
epoch
,
model_dir
)
in
model_list
:
print
(
"Begin to infer No.{} model, model_dir: {}"
.
format
(
epoch
,
model_dir
))
...
...
@@ -258,7 +261,8 @@ class TranspileTrainer(Trainer):
while
True
:
metrics_rets
=
self
.
_exe
.
run
(
program
=
program
,
fetch_list
=
metrics_varnames
)
fetch_list
=
metrics_varnames
,
return_numpy
=
is_return_numpy
)
metrics
=
[
epoch
,
batch_id
]
metrics
.
extend
(
metrics_rets
)
...
...
models/recall/gru4rec/config.yaml
浏览文件 @
6fa3457a
...
...
@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate
:
reader
:
batch_size
:
1
class
:
"
{workspace}/rsc15_infer_reader.py"
test_data_path
:
"
{workspace}/data/train"
is_return_numpy
:
False
train
:
trainer
:
# for cluster training
...
...
@@ -19,8 +27,8 @@ train:
epochs
:
3
workspace
:
"
paddlerec.models.recall.gru4rec"
device
:
cpu
reader
:
batch_size
:
5
class
:
"
{workspace}/rsc15_reader.py"
...
...
models/recall/gru4rec/model.py
浏览文件 @
6fa3457a
...
...
@@ -23,7 +23,7 @@ class Model(ModelBase):
def
__init__
(
self
,
config
):
ModelBase
.
__init__
(
self
,
config
)
def
all_vocab_network
(
self
):
def
all_vocab_network
(
self
,
is_infer
=
False
):
""" network definition """
recall_k
=
envs
.
get_global_env
(
"hyper_parameters.recall_k"
,
None
,
self
.
_namespace
)
vocab_size
=
envs
.
get_global_env
(
"hyper_parameters.vocab_size"
,
None
,
self
.
_namespace
)
...
...
@@ -39,10 +39,16 @@ class Model(ModelBase):
dst_wordseq
=
fluid
.
data
(
name
=
"dst_wordseq"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
,
lod_level
=
1
)
if
is_infer
:
self
.
_infer_data_var
=
[
src_wordseq
,
dst_wordseq
]
self
.
_infer_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
self
.
_infer_data_var
,
capacity
=
64
,
use_double_buffer
=
False
,
iterable
=
False
)
emb
=
fluid
.
embedding
(
input
=
src_wordseq
,
size
=
[
vocab_size
,
hid_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"emb"
,
initializer
=
fluid
.
initializer
.
Uniform
(
low
=
init_low_bound
,
high
=
init_high_bound
),
learning_rate
=
emb_lr_x
),
...
...
@@ -70,6 +76,9 @@ class Model(ModelBase):
learning_rate
=
fc_lr_x
))
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
fc
,
label
=
dst_wordseq
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
fc
,
label
=
dst_wordseq
,
k
=
recall_k
)
if
is_infer
:
self
.
_infer_results
[
'recall20'
]
=
acc
return
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
self
.
_data_var
.
append
(
src_wordseq
)
...
...
@@ -84,4 +93,4 @@ class Model(ModelBase):
def
infer_net
(
self
):
pass
self
.
all_vocab_network
(
is_infer
=
True
)
models/recall/gru4rec/rsc15_infer_reader.py
0 → 100644
浏览文件 @
6fa3457a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddlerec.core.reader
import
Reader
from
paddlerec.core.utils
import
envs
class
EvaluateReader
(
Reader
):
def
init
(
self
):
pass
def
generate_sample
(
self
,
line
):
"""
Read the data line by line and process it as a dictionary
"""
def
reader
():
"""
This function needs to be implemented by the user, based on data format
"""
l
=
line
.
strip
().
split
()
l
=
[
w
for
w
in
l
]
src_seq
=
l
[:
len
(
l
)
-
1
]
src_seq
=
[
int
(
e
)
for
e
in
src_seq
]
trg_seq
=
l
[
1
:]
trg_seq
=
[
int
(
e
)
for
e
in
trg_seq
]
feature_name
=
[
"src_wordseq"
,
"dst_wordseq"
]
yield
zip
(
feature_name
,
[
src_seq
]
+
[
trg_seq
])
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录