# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
import logging

import paddle
import paddle.fluid as fluid

import reader as reader
from nets import SequenceSemanticRetrieval

# Install a timestamped record format on the root handler, then make sure the
# "fluid" logger lets INFO records through so training progress is visible.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``train_file``, ``valid_file``,
        ``epochs``, ``model_output_dir``, ``sequence_encode_dim``,
        ``matching_dim``, ``batch_size`` and ``embedding_dim``.  The two file
        options default to ``None`` when they are not given.
    """
    parser = argparse.ArgumentParser("sequence semantic retrieval")
    parser.add_argument("--train_file", type=str, help="Training file")
    parser.add_argument("--valid_file", type=str, help="Validation file")
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs for training")
    parser.add_argument(
        "--model_output_dir",
        type=str,
        default='model_output',
        help="Model output folder")
    parser.add_argument(
        "--sequence_encode_dim",
        type=int,
        default=128,
        help="Dimension of sequence encoder output")
    parser.add_argument(
        "--matching_dim",
        type=int,
        default=128,
        help="Dimension of hidden layer")
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training")
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=128,
        help="Default Dimension of Embedding")
    return parser.parse_args()

def start_train(args):
    """Build the retrieval network, train it on CPU and save it every epoch.

    Args:
        args: Parsed options. This function reads ``train_file``,
            ``batch_size``, ``embedding_dim``, ``matching_dim``, ``epochs``
            and ``model_output_dir``.

    Side effects:
        Logs one line per training batch and calls
        ``fluid.io.save_inference_model`` on ``args.model_output_dir`` at the
        end of every epoch.
    """
    y_vocab = reader.YoochooseVocab()
    y_vocab.load([args.train_file])
    vocab_size = len(y_vocab.get_vocab())
    logger.info("Load yoochoose vocabulary size: {}".format(vocab_size))

    # Shuffle within a buffer of 100 batches, then group samples into batches.
    y_data = reader.YoochooseDataset(y_vocab)
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            y_data.train([args.train_file]), buf_size=args.batch_size * 100),
        batch_size=args.batch_size)

    place = fluid.CPUPlace()
    ssr = SequenceSemanticRetrieval(
        vocab_size, args.embedding_dim, args.matching_dim)
    input_data, user_rep, item_rep, avg_cost, acc = ssr.train()

    optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
    optimizer.minimize(avg_cost)
    startup_program = fluid.default_startup_program()
    loop_program = fluid.default_main_program()

    # The same feed-variable names drive both the feeder and the saved model.
    data_list = [var.name for var in input_data]
    feeder = fluid.DataFeeder(feed_list=data_list, place=place)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    for pass_id in range(args.epochs):
        for batch_id, data in enumerate(train_reader()):
            loss_val, correct_val = exe.run(loop_program,
                                            feed=feeder.feed(data),
                                            fetch_list=[avg_cost, acc])
            # Divide by the real batch length: the final batch of an epoch can
            # hold fewer than args.batch_size samples.
            # NOTE(review): this treats `acc` as a per-batch count of correct
            # predictions, as the original division implied -- confirm in nets.
            logger.info("Train --> pass: {} batch_id: {} avg_cost: {}, acc: {}".
                        format(pass_id, batch_id, loss_val,
                               float(correct_val) / len(data)))
        # Reuse data_list here; the original comprehension iterated with `val`
        # but read `var`, so it never produced the intended feed names.
        fluid.io.save_inference_model(args.model_output_dir,
                                      data_list,
                                      [user_rep, item_rep, avg_cost, acc], exe)

def main():
    """Script entry point: parse the command line and run training with it."""
    start_train(parse_args())

# Start training only when the file is executed as a script, so importing the
# module does not kick off a training run.
if __name__ == "__main__":
    main()