# config.py
import argparse
"""
global params
"""


def boolean_string(s):
    """Convert a ``'true'``/``'false'`` string (case-insensitive) to a bool.

    Used below as an ``argparse`` ``type=`` converter, so options such as
    ``--use_bn false`` can switch a True-by-default flag off. Raising
    ``ValueError`` lets ``argparse`` turn bad input into a usage error.

    Args:
        s: The string to convert. A ``bool`` is returned unchanged, so the
            helper is safe to apply to a value that was already converted.

    Returns:
        bool: ``True`` for ``'true'``, ``False`` for ``'false'`` (any case).

    Raises:
        ValueError: if ``s`` is neither ``'true'`` nor ``'false'``.
    """
    if isinstance(s, bool):
        return s
    lowered = s.lower()
    if lowered not in {'false', 'true'}:
        # Include the offending value so direct (non-argparse) callers can
        # see what was rejected.
        raise ValueError('Not a valid boolean string: {!r}'.format(s))
    return lowered == 'true'


def parse_args(argv=None):
    """Build the command-line parser for the DCN demo and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` falls back to ``sys.argv[1:]``, which is
            the same behaviour as the previous zero-argument version.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="PaddleFluid DCN demo")

    # Input / output paths.
    parser.add_argument(
        '--train_data_dir',
        type=str,
        default='data/train',
        help='The path of train data')
    parser.add_argument(
        '--test_valid_data_dir',
        type=str,
        default='data/test_valid',
        help='The path of test and valid data')
    parser.add_argument(
        '--vocab_dir',
        type=str,
        default='data/vocab',
        help='The path of generated vocabs')
    parser.add_argument(
        '--cat_feat_num',
        type=str,
        default='data/cat_feature_num.txt',
        help='The path of generated cat_feature_num.txt')

    # Training / runtime options.
    parser.add_argument(
        '--batch_size', type=int, default=512, help="Batch size")
    # NOTE(review): --steps has a non-None default, so it is always "set";
    # whether --num_epoch is ever honoured depends on the training loop
    # elsewhere -- confirm and adjust the help text if needed.
    parser.add_argument(
        '--steps',
        type=int,
        default=150000,
        help="Early stop steps in training. If set, num_epoch will not work")
    parser.add_argument('--num_epoch', type=int, default=2, help="train epoch")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store')
    parser.add_argument(
        '--num_thread', type=int, default=20, help='The number of threads')
    parser.add_argument('--test_epoch', type=str, default='1')

    # Model architecture and regularisation.
    parser.add_argument(
        '--dnn_hidden_units',
        nargs='+',
        type=int,
        default=[1024, 1024],
        help='DNN layers and hidden units')
    parser.add_argument(
        '--cross_num',
        type=int,
        default=6,
        help='The number of Cross network layers')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument(
        '--l2_reg_cross',
        type=float,
        default=1e-5,
        help='Cross net l2 regularizer coefficient')
    # A typed boolean (not store_true) so "--use_bn false" can turn off a
    # flag whose default is True.
    parser.add_argument(
        '--use_bn',
        type=boolean_string,
        default=True,
        help='Whether use batch norm in dnn part')
    # store_true already implies default=False and a non-required option.
    parser.add_argument(
        '--is_sparse',
        action='store_true',
        help='embedding will use sparse or not, (default: False)')
    parser.add_argument(
        '--clip_by_norm', type=float, default=100.0, help="gradient clip norm")
    parser.add_argument('--print_steps', type=int, default=100)

    return parser.parse_args(argv)