args.py 5.3 KB
Newer Older
C
Chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
# -*- coding=utf-8 -*-
"""
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import six
import argparse


def str2bool(v):
    """
    Convert a command-line value to a bool.

    argparse's ``type=bool`` would call ``bool(s)``, which is True for every
    non-empty string (including "False"), so boolean options use this
    converter instead.

    Args:
        v: the raw value. Strings are compared case-insensitively after
            stripping surrounding whitespace. A real ``bool`` is returned
            unchanged; any other object is converted with ``str()`` first.

    Returns:
        True if ``v`` is one of "true", "t", "1", "yes", "y" (any case),
        otherwise False.
    """
    # Pass real booleans through so programmatic callers do not crash on
    # ``.lower()`` (bool has no such method).
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in ("true", "t", "1", "yes", "y")


class ArgumentGroup(object):
    """
    Thin wrapper around an argparse argument group.

    Every option is registered as ``--<name>`` and its help text is suffixed
    with the option's default value.
    """

    def __init__(self, parser, title, des):
        """
        Create a new argument group on ``parser``.

        Args:
            parser: the ``argparse.ArgumentParser`` to attach the group to.
            title: group title shown in ``--help``.
            des: group description shown in ``--help``.
        """
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, type, default, help, **kwargs):
        """
        Register the option ``--<name>`` on this group.

        Args:
            name: option name without the leading ``--``. argparse derives
                the attribute name from it, replacing ``-`` with ``_``.
            type: value converter. ``bool`` is replaced by ``str2bool`` so
                that "False"/"0" parse as False. ``list`` makes the option
                accept one or more space-separated values (``nargs='+'``);
                each element is converted with the type of ``default[0]``,
                or kept as ``str`` when ``default`` is not a non-empty
                list/tuple.
            default: value used when the option is not given.
            help: help text; " Default: <default>." is appended.
            **kwargs: forwarded to ``add_argument``. An explicit ``nargs``
                takes precedence over the ``'+'`` used for list options.
        """
        if type == list:
            # Handing ``list`` to argparse as the converter would turn a
            # single token into its characters (list("12") == ['1', '2'])
            # and reject further values, so collect the elements with
            # nargs and convert each one individually instead.
            # ``type`` is shadowed by the parameter, hence ``__class__``.
            if isinstance(default, (list, tuple)) and default:
                type = default[0].__class__
            else:
                type = str
            kwargs.setdefault("nargs", "+")
        type = str2bool if type == bool else type
        self._group.add_argument(
            "--" + name,
            default=default,
            type=type,
            help=help + ' Default: %(default)s.',
            **kwargs)


def parse_args(argv=None):
    """
    Define all command-line options of the program and parse them.

    Options are registered in three argparse groups: "main" (global run
    settings), "model" (training hyper-parameters, tree shape and infer
    settings) and "path" (data/model/tree file locations and init flags).

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, as before; passing an
            explicit list makes the function usable from tests/scripts.

    Returns:
        argparse.Namespace with one attribute per option. ``--warm-up`` is
        exposed as ``warm_up`` because argparse replaces ``-`` with ``_``
        when deriving attribute names.
    """
    parser = argparse.ArgumentParser("main")

    # global
    main_g = ArgumentGroup(parser, "main", "global conf")
    main_g.add_arg("random_seed", int, 0, "random_seed")
    main_g.add_arg("cpu_num", int, 1, "cpu_num")
    main_g.add_arg("is_local", bool, False,
                   "whether to perform local training")
    main_g.add_arg("is_cloud", bool, False, "")
    main_g.add_arg("is_test", bool, False, "")
    main_g.add_arg("sync_mode", str, "async", "distributed training mode")
    main_g.add_arg("need_trace", bool, False, "")
    main_g.add_arg("need_detail", bool, False, "")

    # model
    model_g = ArgumentGroup(
        parser, "model", "options to init, resume and save model.")
    model_g.add_arg("epoch_num", int, 3, "number of epochs for train")
    model_g.add_arg("batch_size", int, 16, "batch size for train")
    model_g.add_arg("learning_rate", float, 5e-5,
                    "learning rate for global training")

    model_g.add_arg("layer_size", int, 4, "layer size")
    model_g.add_arg("node_nums", int, 26, "tree node nums")
    model_g.add_arg("node_emb_size", int, 64, "node embedding size")
    model_g.add_arg("query_emb_size", int, 768, "input query embedding size")
    # Per-layer integer lists, given as space-separated values on the
    # command line (e.g. ``--neg_sampling_list 1 2 3 4``). Using
    # ``type=list`` here would make argparse split a single token into its
    # characters (list("12") == ['1', '2']) and reject additional values.
    model_g.add_arg("neg_sampling_list", int, [1, 2, 3, 4],
                    "nce sample nums at every layer", nargs='+')
    model_g.add_arg("layer_node_num_list", int, [2, 4, 7, 12],
                    "node nums at every layer", nargs='+')
    model_g.add_arg("leaf_node_num", int, 13, "leaf node nums")

    # for infer
    model_g.add_arg("child_nums", int, 2, "child node of ancestor node")
    model_g.add_arg("topK", int, 1, "best recall result nums")

    # paths: every option below is registered in the "path" group
    path_g = ArgumentGroup(
        parser, "path", "files path of data & model.")
    path_g.add_arg("train_files_path", str, "./data/train", "train data path")
    path_g.add_arg("test_files_path", str, "./data/test", "test data path")
    path_g.add_arg("model_files_path", str, "./models", "model data path")

    # build tree and warm up
    path_g.add_arg("build_tree_init_path", str,
                   "./data/gen_tree/demo_fake_input.txt",
                   "build tree embedding path")
    # Exposed as ``args.warm_up`` (argparse converts '-' to '_').
    path_g.add_arg("warm-up", bool, False,
                   "warm up, building new tree.")
    path_g.add_arg("rebuild_tree_per_epochs", int, -1,
                   "re-build tree per epochs, -1 means no re-building")

    path_g.add_arg("tree_info_init_path", str,
                   "./thirdparty/tree_info.txt", "embedding file path")
    path_g.add_arg("tree_travel_init_path", str,
                   "./thirdparty/travel_list.txt", "TDM tree travel file path")
    path_g.add_arg("tree_layer_init_path", str,
                   "./thirdparty/layer_list.txt", "TDM tree layer file path")
    path_g.add_arg("tree_emb_init_path", str,
                   "./thirdparty/tree_emb.txt", "TDM tree emb file path")

    path_g.add_arg("load_model", bool, False,
                   "whether load model(paddle persistables model)")
    path_g.add_arg("save_init_model", bool, False,
                   "whether save init model(paddle persistables model)")
    path_g.add_arg("init_model_files_path", str, "./models/init_model",
                   "init model params by paddle model files for training")
    path_g.add_arg("infer_model_files_path", str, "./models/init_model",
                   "model files path for infer")

    return parser.parse_args(argv)


def print_arguments(args):
    """
    Print every parsed argument as a ``name: value`` line.

    Lines are ordered by argument name and framed by a header line and a
    closing rule line.

    Args:
        args: any object accepted by ``vars()``, e.g. the Namespace
            returned by ``parse_args``.
    """
    print('-----------  Configuration Arguments -----------')
    ordered_pairs = sorted(vars(args).items())
    for key, val in ordered_pairs:
        line = '%s: %s' % (key, val)
        print(line)
    print('------------------------------------------------')