args.py 4.8 KB
Newer Older
Q
qiuxuezhong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import distutils.util


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
X
xuezhong 已提交
25 26 27 28
    parser.add_argument(
        '--prepare',
        action='store_true',
        help='create the directories, prepare the vocabulary and embeddings')
Q
qiuxuezhong 已提交
29
    parser.add_argument('--train', action='store_true', help='train the model')
X
xuezhong 已提交
30 31 32 33 34 35 36 37 38 39
    parser.add_argument('--evaluate', action='store_true', help='evaluate the model on dev set')
    parser.add_argument('--predict', action='store_true',
                        help='predict the answers for test set with trained model')

    parser.add_argument("--embed_size", type=int, default=300,
                        help="The dimension of embedding table. (default: %(default)d)")
    parser.add_argument("--hidden_size", type=int, default=150,
                        help="The size of rnn hidden unit. (default: %(default)d)")
    parser.add_argument("--learning_rate", type=float, default=0.001,
                        help="Learning rate used to train the model. (default: %(default)f)")
Q
qiuxuezhong 已提交
40
    parser.add_argument('--optim', default='adam', help='optimizer type')
X
xuezhong 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54
    parser.add_argument("--weight_decay", type=float, default=0.0001,
                        help="Weight decay. (default: %(default)f)")

    parser.add_argument('--drop_rate', type=float, default=0.0, help="Dropout probability")
    parser.add_argument('--random_seed', type=int, default=123)
    parser.add_argument("--batch_size", type=int, default=32,
                        help="The sequence number of a mini-batch data. (default: %(default)d)")
    parser.add_argument("--pass_num", type=int, default=5,
                        help="The number epochs to train. (default: %(default)d)")
    parser.add_argument("--use_gpu", type=distutils.util.strtobool, default=True,
                        help="Whether to use gpu. (default: %(default)d)")
    parser.add_argument("--log_interval", type=int, default=50,
                        help="log the train loss every n batches. (default: %(default)d)")

Q
qiuxuezhong 已提交
55 56 57
    parser.add_argument('--max_p_num', type=int, default=5)
    parser.add_argument('--max_a_len', type=int, default=200)
    parser.add_argument('--max_p_len', type=int, default=500)
X
xuezhong 已提交
58
    parser.add_argument('--max_q_len', type=int, default=60)
Q
qiuxuezhong 已提交
59
    parser.add_argument('--doc_num', type=int, default=5)
X
xuezhong 已提交
60

Y
Yibing Liu 已提交
61 62
    parser.add_argument('--vocab_dir', default='../data/vocab', help='vocabulary')
    parser.add_argument("--save_dir", type=str, default="../data/models",
X
xuezhong 已提交
63 64 65 66 67 68 69
                        help="Specify the path to save trained models.")
    parser.add_argument("--save_interval", type=int, default=1,
                        help="Save the trained model every n passes. (default: %(default)d)")
    parser.add_argument("--load_dir", type=str, default="",
                        help="Specify the path to load trained models.")
    parser.add_argument('--log_path',
                        help='path of the log file. If not set, logs are printed to console')
Y
Yibing Liu 已提交
70
    parser.add_argument('--result_dir', default='../data/results/',
X
xuezhong 已提交
71 72 73 74 75
                        help='the dir to output the results')
    parser.add_argument('--result_name', default='test_result',
                        help='the file name of the predicted results')

    parser.add_argument('--trainset', nargs='+',
Y
Yibing Liu 已提交
76
                        default=['../data/demo/trainset/search.train.json'],
X
xuezhong 已提交
77 78
                        help='train dataset')
    parser.add_argument('--devset', nargs='+',
Y
Yibing Liu 已提交
79
                        default=['../data/demo/devset/search.dev.json'],
X
xuezhong 已提交
80 81
                        help='dev dataset')
    parser.add_argument('--testset', nargs='+',
Y
Yibing Liu 已提交
82
                        default=['../data/demo/testset/search.test.json'],
X
xuezhong 已提交
83 84 85 86 87 88 89
                        help='test dataset')

    parser.add_argument("--enable_ce", action='store_true',
                        help="If set, run the task with continuous evaluation logs.")
    parser.add_argument('--para_print', action='store_true', help="Print debug info")
    parser.add_argument("--dev_interval", type=int, default=-1,
                        help="evaluate on dev set loss every n batches. (default: %(default)d)")
Q
qiuxuezhong 已提交
90 91
    args = parser.parse_args()
    return args