diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index b50541241774ad771e7838925f9f3e7ad257d0ac..03e252ca2a0f89ab30cd0c6aecc544918e5e509b 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -247,6 +247,9 @@ class TranspileTrainer(Trainer): model_list = [(0, envs.get_global_env( 'evaluate_model_path', "", namespace='evaluate'))] + is_return_numpy = envs.get_global_env( + 'is_return_numpy', True, namespace='evaluate') + for (epoch, model_dir) in model_list: print("Begin to infer No.{} model, model_dir: {}".format( epoch, model_dir)) @@ -258,7 +261,8 @@ class TranspileTrainer(Trainer): while True: metrics_rets = self._exe.run( program=program, - fetch_list=metrics_varnames) + fetch_list=metrics_varnames, + return_numpy=is_return_numpy) metrics = [epoch, batch_id] metrics.extend(metrics_rets) diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 2668fb9a55efa3ec411c92c770acc1b8158e7e88..744515b4f453756545b7171f8c7285042c8afca5 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+evaluate: + reader: + batch_size: 1 + class: "{workspace}/rsc15_infer_reader.py" + test_data_path: "{workspace}/data/train" + is_return_numpy: False + + train: trainer: # for cluster training @@ -19,8 +27,8 @@ train: epochs: 3 workspace: "paddlerec.models.recall.gru4rec" - device: cpu + reader: batch_size: 5 class: "{workspace}/rsc15_reader.py" diff --git a/models/recall/gru4rec/model.py b/models/recall/gru4rec/model.py index 3f4bff278e6af584c3a3952282a99c25b1fe0023..d77f3c254f4f20b12854dbf05af35c64613c4b84 100644 --- a/models/recall/gru4rec/model.py +++ b/models/recall/gru4rec/model.py @@ -23,7 +23,7 @@ class Model(ModelBase): def __init__(self, config): ModelBase.__init__(self, config) - def all_vocab_network(self): + def all_vocab_network(self, is_infer=False): """ network definition """ recall_k = envs.get_global_env("hyper_parameters.recall_k", None, self._namespace) vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None, self._namespace) @@ -39,10 +39,16 @@ class Model(ModelBase): dst_wordseq = fluid.data( name="dst_wordseq", shape=[None, 1], dtype="int64", lod_level=1) + if is_infer: + self._infer_data_var = [src_wordseq, dst_wordseq] + self._infer_data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False) + emb = fluid.embedding( input=src_wordseq, size=[vocab_size, hid_size], param_attr=fluid.ParamAttr( + name="emb", initializer=fluid.initializer.Uniform( low=init_low_bound, high=init_high_bound), learning_rate=emb_lr_x), @@ -70,6 +76,9 @@ class Model(ModelBase): learning_rate=fc_lr_x)) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) acc = fluid.layers.accuracy(input=fc, label=dst_wordseq, k=recall_k) + if is_infer: + self._infer_results['recall20'] = acc + return avg_cost = fluid.layers.mean(x=cost) self._data_var.append(src_wordseq) @@ -84,4 +93,4 @@ class Model(ModelBase): def infer_net(self): - pass + self.all_vocab_network(is_infer=True) diff 
--git a/models/recall/gru4rec/rsc15_infer_reader.py b/models/recall/gru4rec/rsc15_infer_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..829726e3d8861292266c25dfba4298a1ee2f502a
--- /dev/null
+++ b/models/recall/gru4rec/rsc15_infer_reader.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+from paddlerec.core.reader import Reader
+from paddlerec.core.utils import envs
+
+
+class EvaluateReader(Reader):
+    """Reader that parses one text line into src/dst id sequences."""
+
+    def init(self):
+        """Intentionally empty hook: this reader sets up no state."""
+        pass
+
+    def generate_sample(self, line):
+        """
+        Build a generator function that parses ``line`` into one sample.
+
+        Args:
+            line: text holding whitespace-separated integer tokens.
+
+        Returns:
+            A zero-argument generator function. Iterating its result yields
+            one list of (feature_name, values) pairs: "src_wordseq" holds
+            every token except the last and "dst_wordseq" holds every token
+            except the first, both converted to int.
+
+        Raises:
+            ValueError: during iteration, if a token placed in either
+                sequence is not a valid integer literal.
+        """
+
+        def reader():
+            """
+            Parse the captured line lazily and yield its single sample.
+            """
+            tokens = line.strip().split()
+            src_seq = [int(token) for token in tokens[:-1]]
+            trg_seq = [int(token) for token in tokens[1:]]
+            feature_name = ["src_wordseq", "dst_wordseq"]
+            # Materialize the pairs: on Python 3 a bare zip() is a one-shot
+            # iterator rather than the list Python 2 returned.
+            yield list(zip(feature_name, [src_seq, trg_seq]))
+
+        return reader