import logging
import os
import argparse
from collections import defaultdict

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def parse_train_cmd():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle text classification example.")
    parser.add_argument(
        "--nn_type",
        type=str,
        help=("A flag that defines which type of network to use, "
              "available: [dnn, cnn]."),
        default="dnn")
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=False,
        help=("The path of training dataset (default: None). If this parameter "
              "is not set, paddle.dataset.imdb will be used."),
        default=None)
    parser.add_argument(
        "--test_data_dir",
        type=str,
        required=False,
        help=("The path of testing dataset (default: None). If this parameter "
              "is not set, paddle.dataset.imdb will be used."),
        default=None)
    parser.add_argument(
        "--word_dict",
        type=str,
        required=False,
        help=("The path of word dictionary (default: None). If this parameter "
              "is not set, paddle.dataset.imdb will be used. If this parameter "
              "is set, but the file does not exist, word dictionay "
              "will be built from the training data automatically."),
        default=None)
    parser.add_argument(
        "--label_dict",
        type=str,
        required=False,
        help=("The path of label dictionay (default: None).If this parameter "
              "is not set, paddle.dataset.imdb will be used. If this parameter "
              "is set, but the file does not exist, word dictionay "
              "will be built from the training data automatically."),
        default=None)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="The number of training examples in one forward/backward pass.")
    parser.add_argument(
        "--num_passes",
        type=int,
        default=10,
        help="The number of passes to train the model.")
    parser.add_argument(
        "--model_save_dir",
        type=str,
        required=False,
        help=("The path to save the trained models."),
        default="models")

    return parser.parse_args()
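
# Example invocation of a training script built on top of parse_train_cmd()
# (illustrative only: "train.py" and the data paths below are placeholders):
#
#   python train.py --nn_type=cnn --batch_size=64 --num_passes=5 \
#       --train_data_dir=./data/train --test_data_dir=./data/test \
#       --word_dict=./word_dict.txt --label_dict=./label_dict.txt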


def build_dict(data_dir,
               save_path,
               use_col=0,
               cutoff_fre=0,
               insert_extra_words=[]):
    """Build a word-frequency dictionary from every file under data_dir.

    Each line is split on tabs and the tokens in column `use_col` are
    counted. Words occurring fewer than `cutoff_fre` times are dropped,
    and any `insert_extra_words` (e.g. special tokens) are written first
    with a placeholder frequency of -1. The result is saved to `save_path`
    as "word<TAB>count" lines, sorted by descending count.
    """
    word_count = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line_splits = line.strip().split("\t")
                if len(line_splits) <= use_col:
                    continue
                for w in line_splits[use_col].split():
                    word_count[w] += 1

    with open(save_path, "w") as f:
        for w in insert_extra_words:
            f.write("%s\t-1\n" % w)

        for w, count in sorted(
                word_count.items(), key=lambda x: x[1], reverse=True):
            if count < cutoff_fre:
                break
            f.write("%s\t%d\n" % (w, count))


def load_dict(dict_path):
    """Load the dictionary at dict_path and map each word (the first
    tab-separated column of a line) to its 0-based line index."""
    with open(dict_path, "r") as f:
        return dict((line.strip().split("\t")[0], idx)
                    for idx, line in enumerate(f))


def load_reverse_dict(dict_path):
    """Load the dictionary at dict_path and map each 0-based line index to
    the word in the first tab-separated column of that line."""
    with open(dict_path, "r") as f:
        return dict((idx, line.strip().split("\t")[0])
                    for idx, line in enumerate(f))
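

if __name__ == "__main__":
    # A minimal smoke test of the helpers above (illustrative only: the
    # temporary files, the "<unk>" token, and cutoff_fre=1 are assumptions,
    # not requirements of these functions).
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    data_dir = os.path.join(tmp_dir, "data")
    os.mkdir(data_dir)
    with open(os.path.join(data_dir, "part-0"), "w") as f:
        f.write("the quick brown fox\t0\n")
        f.write("the lazy dog\t1\n")

    dict_path = os.path.join(tmp_dir, "word_dict.txt")
    build_dict(
        data_dir,
        dict_path,
        use_col=0,
        cutoff_fre=1,
        insert_extra_words=["<unk>"])

    word_to_id = load_dict(dict_path)
    id_to_word = load_reverse_dict(dict_path)
    assert word_to_id["<unk>"] == 0
    assert id_to_word[0] == "<unk>"
    logger.info("built a word dictionary with %d entries", len(word_to_id))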