From 6fa3457ada1a3991674ec1c0a5ed80eb6c2b8cb0 Mon Sep 17 00:00:00 2001 From: zhangwenhui03 Date: Thu, 14 May 2020 17:08:51 +0800 Subject: [PATCH] add gru4rec infer --- core/trainers/transpiler_trainer.py | 6 ++- models/recall/gru4rec/config.yaml | 10 ++++- models/recall/gru4rec/model.py | 13 ++++++- models/recall/gru4rec/rsc15_infer_reader.py | 42 +++++++++++++++++++++ 4 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 models/recall/gru4rec/rsc15_infer_reader.py diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index b5054124..03e252ca 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -247,6 +247,9 @@ class TranspileTrainer(Trainer): model_list = [(0, envs.get_global_env( 'evaluate_model_path', "", namespace='evaluate'))] + is_return_numpy = envs.get_global_env( + 'is_return_numpy', True, namespace='evaluate') + for (epoch, model_dir) in model_list: print("Begin to infer No.{} model, model_dir: {}".format( epoch, model_dir)) @@ -258,7 +261,8 @@ class TranspileTrainer(Trainer): while True: metrics_rets = self._exe.run( program=program, - fetch_list=metrics_varnames) + fetch_list=metrics_varnames, + return_numpy=is_return_numpy) metrics = [epoch, batch_id] metrics.extend(metrics_rets) diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 2668fb9a..744515b4 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +evaluate: + reader: + batch_size: 1 + class: "{workspace}/rsc15_infer_reader.py" + test_data_path: "{workspace}/data/train" + is_return_numpy: False + + train: trainer: # for cluster training @@ -19,8 +27,8 @@ train: epochs: 3 workspace: "paddlerec.models.recall.gru4rec" - device: cpu + reader: batch_size: 5 class: "{workspace}/rsc15_reader.py" diff --git a/models/recall/gru4rec/model.py b/models/recall/gru4rec/model.py index 3f4bff27..d77f3c25 100644 --- a/models/recall/gru4rec/model.py +++ b/models/recall/gru4rec/model.py @@ -23,7 +23,7 @@ class Model(ModelBase): def __init__(self, config): ModelBase.__init__(self, config) - def all_vocab_network(self): + def all_vocab_network(self, is_infer=False): """ network definition """ recall_k = envs.get_global_env("hyper_parameters.recall_k", None, self._namespace) vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None, self._namespace) @@ -39,10 +39,16 @@ class Model(ModelBase): dst_wordseq = fluid.data( name="dst_wordseq", shape=[None, 1], dtype="int64", lod_level=1) + if is_infer: + self._infer_data_var = [src_wordseq, dst_wordseq] + self._infer_data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False) + emb = fluid.embedding( input=src_wordseq, size=[vocab_size, hid_size], param_attr=fluid.ParamAttr( + name="emb", initializer=fluid.initializer.Uniform( low=init_low_bound, high=init_high_bound), learning_rate=emb_lr_x), @@ -70,6 +76,9 @@ class Model(ModelBase): learning_rate=fc_lr_x)) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) acc = fluid.layers.accuracy(input=fc, label=dst_wordseq, k=recall_k) + if is_infer: + self._infer_results['recall20'] = acc + return avg_cost = fluid.layers.mean(x=cost) self._data_var.append(src_wordseq) @@ -84,4 +93,4 @@ class Model(ModelBase): def infer_net(self): - pass + self.all_vocab_network(is_infer=True) diff --git a/models/recall/gru4rec/rsc15_infer_reader.py b/models/recall/gru4rec/rsc15_infer_reader.py new file mode 100644 index 00000000..829726e3 --- /dev/null +++ b/models/recall/gru4rec/rsc15_infer_reader.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from paddlerec.core.reader import Reader +from paddlerec.core.utils import envs + + +class EvaluateReader(Reader): + def init(self): + pass + + def generate_sample(self, line): + """ + Read the data line by line and process it as a dictionary + """ + + def reader(): + """ + This function needs to be implemented by the user, based on data format + """ + l = line.strip().split() + l = [w for w in l] + src_seq = l[:len(l) - 1] + src_seq = [int(e) for e in src_seq] + trg_seq = l[1:] + trg_seq = [int(e) for e in trg_seq] + feature_name = ["src_wordseq", "dst_wordseq"] + yield zip(feature_name, [src_seq] + [trg_seq]) + + return reader -- GitLab