From 70150f9dc464f6af76f665553578a24b85f7e1c0 Mon Sep 17 00:00:00 2001
From: frankwhzhang
Date: Wed, 28 Nov 2018 19:16:26 +0800
Subject: [PATCH] add infer for multiview-simnet

---
 fluid/PaddleRec/multiview_simnet/README.cn.md |  7 +-
 fluid/PaddleRec/multiview_simnet/README.md    |  7 +-
 fluid/PaddleRec/multiview_simnet/infer.py     | 97 +++++++++++++++++++
 3 files changed, 109 insertions(+), 2 deletions(-)
 create mode 100644 fluid/PaddleRec/multiview_simnet/infer.py

diff --git a/fluid/PaddleRec/multiview_simnet/README.cn.md b/fluid/PaddleRec/multiview_simnet/README.cn.md
index bd35ba7d..06df3c32 100644
--- a/fluid/PaddleRec/multiview_simnet/README.cn.md
+++ b/fluid/PaddleRec/multiview_simnet/README.cn.md
@@ -15,8 +15,13 @@
 ```bash
 python train.py
 ```
+## 预测
+如下命令行可以获得预测工具的具体选项,`python infer.py -h`内容可以参考说明
+```bash
+python infer.py
+```
+
 ## 未来的工作
 - 多种pairwise的损失函数会被加入到这个项目中。对于不同视角的特征,用户-项目之间的匹配关系可以使用不同的损失函数进行联合优化。整个模型会在真实数据中进行验证。
-- 推理工具会被加入
 - Parallel Executor选项会被加入
 - 分布式训练能力会被加入
diff --git a/fluid/PaddleRec/multiview_simnet/README.md b/fluid/PaddleRec/multiview_simnet/README.md
index 253f4cba..525946e6 100644
--- a/fluid/PaddleRec/multiview_simnet/README.md
+++ b/fluid/PaddleRec/multiview_simnet/README.md
@@ -15,8 +15,13 @@ The command line options for training can be listed by `python train.py -h`
 python train.py
 ```
 
+## Infer
+The command line options for inference can be listed by `python infer.py -h`
+```bash
+python infer.py
+```
+
 ## Future work
 - Multiple types of pairwise loss will be added in this project. For different views of features between a user and an item, multiple losses will be supported. The model will be verified in real world dataset.
-- infer will be added
 - Parallel Executor will be added in this project
 - Distributed Training will be added
diff --git a/fluid/PaddleRec/multiview_simnet/infer.py b/fluid/PaddleRec/multiview_simnet/infer.py
new file mode 100644
index 00000000..80262bb5
--- /dev/null
+++ b/fluid/PaddleRec/multiview_simnet/infer.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import time
+import six
+import numpy as np
+import math
+import argparse
+import logging
+import paddle.fluid as fluid
+import paddle
+import reader as reader
+from nets import MultiviewSimnet, SimpleEncoderFactory
+
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+
+def parse_args():
+    """Parse the command line options of the inference tool.
+
+    Within this file only --model_dir, --batch_size, --sparse_feature_dim,
+    --query_slots and --title_slots are read; the remaining options are
+    accepted but unused here.
+
+    Returns:
+        argparse.Namespace: the parsed options.
+    """
+    parser = argparse.ArgumentParser("multi-view simnet")
+    parser.add_argument(
+        "--train_file", type=str, help="Training file")
+    parser.add_argument(
+        "--valid_file", type=str, help="Validation file")
+    parser.add_argument(
+        "--epochs", type=int, default=10, help="Number of epochs for training")
+    parser.add_argument(
+        "--model_dir", type=str, default='model_output', help="Model output folder")
+    parser.add_argument(
+        "--query_slots", type=int, default=1, help="Number of query slots")
+    parser.add_argument(
+        "--title_slots", type=int, default=1, help="Number of title slots")
+    parser.add_argument(
+        "--query_encoder", type=str, default="bow", help="Encoder module for slot encoding")
+    parser.add_argument(
+        "--title_encoder", type=str, default="bow", help="Encoder module for slot encoding")
+    parser.add_argument(
+        "--query_encode_dim", type=int, default=128, help="Dimension of query encoder output")
+    parser.add_argument(
+        "--title_encode_dim", type=int, default=128, help="Dimension of title encoder output")
+    parser.add_argument(
+        "--batch_size", type=int, default=128, help="Batch size for training")
+    parser.add_argument(
+        "--embedding_dim", type=int, default=128, help="Default Dimension of Embedding")
+    parser.add_argument(
+        "--sparse_feature_dim", type=int, default=1000001, help="Sparse feature hashing space for index processing")
+    parser.add_argument(
+        "--hidden_size", type=int, default=128, help="Hidden dim")
+    return parser.parse_args()
+
+
+def start_infer(args, model_path):
+    """Run a saved inference model over the synthetic validation data.
+
+    One INFO line with the fetched loss and the batch accuracy is logged
+    per batch.
+
+    Args:
+        args: parsed options; sparse_feature_dim, query_slots, title_slots
+            and batch_size are read here.
+        model_path: directory handed to fluid.io.load_inference_model.
+    """
+    dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
+                                      args.title_slots)
+    test_reader = paddle.batch(
+        paddle.reader.shuffle(
+            dataset.valid(), buf_size=args.batch_size * 100),
+        batch_size=args.batch_size)
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    with fluid.scope_guard(fluid.core.Scope()):
+        # Load from the caller-supplied path rather than re-reading
+        # args.model_dir, so the model_path argument is actually honoured.
+        infer_program, feed_target_names, fetch_vars = (
+            fluid.io.load_inference_model(model_path, exe))
+        feeder = fluid.DataFeeder(
+            program=infer_program, feed_list=feed_target_names, place=place)
+        # NOTE(review): assumes the saved model has exactly two fetch targets
+        # (loss, number of correct predictions) -- confirm against train.py.
+        for batch_id, data in enumerate(test_reader()):
+            loss_val, correct_val = exe.run(infer_program,
+                                            feed=feeder.feed(data),
+                                            fetch_list=fetch_vars)
+            # paddle.batch may yield a final batch shorter than batch_size,
+            # so normalise by the number of samples actually fed.
+            logger.info("INFER --> step: {} batch_id: {} avg_cost: {}, acc: {}"
+                        .format(batch_id + 1, batch_id, loss_val,
+                                float(correct_val) / len(data)))
+
+
+def main():
+    """Parse the command line and run inference on args.model_dir."""
+    args = parse_args()
+    start_infer(args, args.model_dir)
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab