From 3b59e3c42b30b86bba30a0fccb007024a802868c Mon Sep 17 00:00:00 2001 From: zhangwenhui03 Date: Fri, 15 May 2020 14:35:12 +0800 Subject: [PATCH] add ssr infer --- models/recall/ssr/config.yaml | 8 +++++ models/recall/ssr/model.py | 42 ++++++++++++++++++++++++- models/recall/ssr/ssr_infer_reader.py | 44 +++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 models/recall/ssr/ssr_infer_reader.py diff --git a/models/recall/ssr/config.yaml b/models/recall/ssr/config.yaml index 0682c652..b7879466 100644 --- a/models/recall/ssr/config.yaml +++ b/models/recall/ssr/config.yaml @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. + +evaluate: + reader: + batch_size: 1 + class: "{workspace}/ssr_infer_reader.py" + test_data_path: "{workspace}/data/train" + is_return_numpy: True + train: trainer: # for cluster training diff --git a/models/recall/ssr/model.py b/models/recall/ssr/model.py index b2fa5c89..3f93c8a1 100644 --- a/models/recall/ssr/model.py +++ b/models/recall/ssr/model.py @@ -134,5 +134,45 @@ class Model(ModelBase): self.train() + def infer(self): + vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None, self._namespace) + emb_dim = envs.get_global_env("hyper_parameters.emb_dim", None, self._namespace) + hidden_size = envs.get_global_env("hyper_parameters.hidden_size", None, self._namespace) + + user_data = fluid.data( + name="user", shape=[None, 1], dtype="int64", lod_level=1) + all_item_data = fluid.data( + name="all_item", shape=[None, vocab_size], dtype="int64") + pos_label = fluid.data(name="pos_label", shape=[None, 1], dtype="int64") + self._infer_data_var = [user_data, all_item_data, pos_label] + self._infer_data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False) + + user_emb = fluid.embedding( + input=user_data, size=[vocab_size, emb_dim], param_attr="emb.item") + all_item_emb = fluid.embedding( + input=all_item_data, size=[vocab_size, emb_dim], param_attr="emb.item") + all_item_emb_re = fluid.layers.reshape(x=all_item_emb, shape=[-1, emb_dim]) + + user_encoder = GrnnEncoder() + user_enc = user_encoder.forward(user_emb) + user_hid = fluid.layers.fc(input=user_enc, + size=hidden_size, + param_attr='user.w', + bias_attr="user.b") + user_exp = fluid.layers.expand(x=user_hid, expand_times=[1, vocab_size]) + user_re = fluid.layers.reshape(x=user_exp, shape=[-1, hidden_size]) + + all_item_hid = fluid.layers.fc(input=all_item_emb_re, + size=hidden_size, + param_attr='item.w', + bias_attr="item.b") + cos_item = fluid.layers.cos_sim(X=all_item_hid, Y=user_re) + all_pre_ = fluid.layers.reshape(x=cos_item, shape=[-1, vocab_size]) + acc = fluid.layers.accuracy(input=all_pre_, label=pos_label, k=20) + + self._infer_results['recall20'] = acc + + def infer_net(self): - pass + self.infer() diff --git a/models/recall/ssr/ssr_infer_reader.py b/models/recall/ssr/ssr_infer_reader.py new file mode 100644 index 00000000..b9e5f726 --- /dev/null +++ b/models/recall/ssr/ssr_infer_reader.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from paddlerec.core.reader import Reader +from paddlerec.core.utils import envs +import random +import numpy as np + + +class EvaluateReader(Reader): + def init(self): + self.vocab_size = envs.get_global_env("vocab_size", 10, "train.model.hyper_parameters") + + + def generate_sample(self, line): + """ + Read the data line by line and process it as a dictionary + """ + + def reader(): + """ + This function needs to be implemented by the user, based on data format + """ + ids = line.strip().split() + conv_ids = [int(i) for i in ids] + boundary = len(ids) - 1 + src = conv_ids[:boundary] + pos_tgt = [conv_ids[boundary]] + feature_name = ["user", "all_item", "p_item"] + yield zip(feature_name, [src] + [np.arange(self.vocab_size).astype("int64").tolist()]+ [pos_tgt]) + + return reader -- GitLab