提交 3865e855 编写于 作者: Z zhangwenhui03 提交者: tangwei

add ssr infer

上级 07f946ab
......@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/ssr_infer_reader.py"
test_data_path: "{workspace}/data/train"
is_return_numpy: True
train:
trainer:
# for cluster training
......
......@@ -134,5 +134,45 @@ class Model(ModelBase):
self.train()
def infer(self):
vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None, self._namespace)
emb_dim = envs.get_global_env("hyper_parameters.emb_dim", None, self._namespace)
hidden_size = envs.get_global_env("hyper_parameters.hidden_size", None, self._namespace)
user_data = fluid.data(
name="user", shape=[None, 1], dtype="int64", lod_level=1)
all_item_data = fluid.data(
name="all_item", shape=[None, vocab_size], dtype="int64")
pos_label = fluid.data(name="pos_label", shape=[None, 1], dtype="int64")
self._infer_data_var = [user_data, all_item_data, pos_label]
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False)
user_emb = fluid.embedding(
input=user_data, size=[vocab_size, emb_dim], param_attr="emb.item")
all_item_emb = fluid.embedding(
input=all_item_data, size=[vocab_size, emb_dim], param_attr="emb.item")
all_item_emb_re = fluid.layers.reshape(x=all_item_emb, shape=[-1, emb_dim])
user_encoder = GrnnEncoder()
user_enc = user_encoder.forward(user_emb)
user_hid = fluid.layers.fc(input=user_enc,
size=hidden_size,
param_attr='user.w',
bias_attr="user.b")
user_exp = fluid.layers.expand(x=user_hid, expand_times=[1, vocab_size])
user_re = fluid.layers.reshape(x=user_exp, shape=[-1, hidden_size])
all_item_hid = fluid.layers.fc(input=all_item_emb_re,
size=hidden_size,
param_attr='item.w',
bias_attr="item.b")
cos_item = fluid.layers.cos_sim(X=all_item_hid, Y=user_re)
all_pre_ = fluid.layers.reshape(x=cos_item, shape=[-1, vocab_size])
acc = fluid.layers.accuracy(input=all_pre_, label=pos_label, k=20)
self._infer_results['recall20'] = acc
def infer_net(self):
pass
self.infer()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
import random
import numpy as np
class EvaluateReader(Reader):
def init(self):
self.vocab_size = envs.get_global_env("vocab_size", 10, "train.model.hyper_parameters")
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
ids = line.strip().split()
conv_ids = [int(i) for i in ids]
boundary = len(ids) - 1
src = conv_ids[:boundary]
pos_tgt = [conv_ids[boundary]]
feature_name = ["user", "all_item", "p_item"]
yield zip(feature_name, [src] + [np.arange(self.vocab_size).astype("int64").tolist()]+ [pos_tgt])
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册