args.py 2.9 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
import argparse

def parse_args(argv=None):
    """Build the command-line parser for the PaddlePaddle CTR example and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, which preserves
            the original zero-argument behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
    # Help strings use argparse's ``%(default)s`` expansion so the documented
    # default can never drift from the real ``default=`` value.
    parser.add_argument(
        '--train_data_path',
        type=str,
        default='./data/raw/train.txt',
        help="The path of training dataset (default: %(default)s)")
    parser.add_argument(
        '--test_data_path',
        type=str,
        default='./data/raw/valid.txt',
        help="The path of testing dataset (default: %(default)s)")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default: %(default)s)")
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=10,
        help="The size for embedding layer (default: %(default)s)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: %(default)s)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: %(default)s)')
    parser.add_argument(
        '--sparse_feature_dim',
        type=int,
        default=1000001,
        help='The sparse feature hashing space for index processing '
             '(default: %(default)s)')
    # The next two options are integer flags (1/0) rather than store_true
    # switches; kept as ints for compatibility with existing callers.
    parser.add_argument(
        '--is_local',
        type=int,
        default=1,
        help='Local train or distributed train (default: %(default)s)')
    parser.add_argument(
        '--cloud_train',
        type=int,
        default=0,
        help='Local train or distributed train on paddlecloud '
             '(default: %(default)s)')
    parser.add_argument(
        '--async_mode',
        action='store_true',
        default=False,
        help='Start the pserver in async mode to support ASGD '
             '(default: %(default)s)')
    parser.add_argument(
        '--no_split_var',
        action='store_true',
        default=False,
        help='If set, do not split variables into blocks when '
             'update_method is pserver (default: %(default)s)')
    parser.add_argument(
        '--role',
        type=str,
        default='pserver',
        help='The role of this process, trainer or pserver '
             '(default: %(default)s)')
    parser.add_argument(
        '--endpoints',
        type=str,
        default='127.0.0.1:6000',
        help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001 '
             '(default: %(default)s)')
    parser.add_argument(
        '--current_endpoint',
        type=str,
        default='127.0.0.1:6000',
        help='The endpoint of the current process (default: %(default)s)')
    parser.add_argument(
        '--trainer_id',
        type=int,
        default=0,
        help='The id of this trainer (default: %(default)s)')
    parser.add_argument(
        '--trainers',
        type=int,
        default=1,
        help='The number of trainers (default: %(default)s)')
    return parser.parse_args(argv)