# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import argparse
import logging
import paddle.fluid as fluid
import paddle
import utils
import numpy as np
from nets import SequenceSemanticRetrieval

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser("sequence semantic retrieval")
    parser.add_argument(
        "--train_dir", type=str, default='train_data', help="Training file")
    parser.add_argument(
        "--base_lr", type=float, default=0.01, help="learning rate")
    parser.add_argument(
        '--vocab_path',
        type=str,
        default='vocab.txt',
        help='vocab file path')
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument(
        '--parallel', type=int, default=0, help='whether to use ParallelExecutor')
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether to use GPU')
    parser.add_argument(
        '--print_batch', type=int, default=10, help='log every N batches')
    parser.add_argument(
        '--model_dir', type=str, default='model_output', help='model dir')
    parser.add_argument(
        "--hidden_size", type=int, default=128, help="hidden size")
    parser.add_argument("--batch_size", type=int, default=50, help="batch size")
    parser.add_argument(
        "--embedding_dim", type=int, default=128, help="embedding dim")
    parser.add_argument(
        '--num_devices', type=int, default=1, help='Number of GPU devices')
    return parser.parse_args()


def get_cards(args):
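    """Return the device count; used below to scale the global batch size."""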
    return args.num_devices


def train(args):
    use_cuda = bool(args.use_cuda)
    parallel = bool(args.parallel)
    print("use_cuda:", use_cuda, "parallel:", parallel)
    train_reader, vocab_size = utils.construct_train_data(
        args.train_dir, args.vocab_path, args.batch_size * get_cards(args))
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    ssr = SequenceSemanticRetrieval(vocab_size, args.embedding_dim,
                                    args.hidden_size)
    # Train program
    train_input_data, cos_pos, avg_cost, acc = ssr.train()

    # Optimizer to minimize the loss
    optimizer = fluid.optimizer.Adagrad(learning_rate=args.base_lr)
    optimizer.minimize(avg_cost)

    data_list = [var.name for var in train_input_data]
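    # DataFeeder maps each mini-batch from the reader onto the named input
    # variables of the train program.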
    feeder = fluid.DataFeeder(feed_list=data_list, place=place)
    exe = fluid.Executor(place)
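    # Run the startup program once to initialize all parameters.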
    exe.run(fluid.default_startup_program())
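    # With --parallel, replicate the program across devices via
    # ParallelExecutor; otherwise train on the single executor.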
    if parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_cuda, loss_name=avg_cost.name)
    else:
        train_exe = exe

    total_time = 0.0
    for pass_id in range(args.epochs):
        epoch_idx = pass_id + 1
        print("epoch_%d start" % epoch_idx)
        t0 = time.time()
        i = 0
        for batch_id, data in enumerate(train_reader()):
            i += 1
            loss_val, correct_val = train_exe.run(
                feed=feeder.feed(data), fetch_list=[avg_cost.name, acc.name])
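            # `acc` appears to hold a per-batch count of correct predictions,
            # hence the division by batch_size below (an assumption based on
            # how nets.SequenceSemanticRetrieval is used here).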
            if i % args.print_batch == 0:
                logger.info(
                    "Train --> pass: {} batch_id: {} avg_cost: {}, acc: {}".
                    format(pass_id, batch_id,
                           np.mean(loss_val),
                           float(np.mean(correct_val)) / args.batch_size))
        t1 = time.time()
        total_time += t1 - t0
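        # time_cost reports the average wall-clock time per epoch so far.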
        print("epoch:%d num_steps:%d time_cost(s):%f" %
              (epoch_idx, i, total_time / epoch_idx))
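        # Save all persistable parameters after every epoch so intermediate
        # models can be evaluated or reused.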
        save_dir = "%s/epoch_%d" % (args.model_dir, epoch_idx)
        fluid.io.save_params(executor=exe, dirname=save_dir)
        print("model saved in %s" % save_dir)


def main():
    args = parse_args()
    train(args)


if __name__ == "__main__":
    main()
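
# A typical invocation (values shown are the script defaults; adjust as needed):
#   python train.py --train_dir train_data --vocab_path vocab.txt \
#       --epochs 10 --batch_size 50 --use_cuda 0 --model_dir model_output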