import logging
import os
import argparse
from collections import defaultdict

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def parse_train_cmd():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle text classification demo")
    parser.add_argument(
        "--nn_type",
        type=str,
        help="define which type of network to use, available: [dnn, cnn]",
        default="dnn")
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=False,
        help=("path of training dataset (default: None). "
              "if this parameter is not set, "
              "paddle.dataset.imdb will be used."),
        default=None)
    parser.add_argument(
        "--test_data_dir",
        type=str,
        required=False,
        help=("path of testing dataset (default: None). "
              "if this parameter is not set, "
              "paddle.dataset.imdb will be used."),
        default=None)
    parser.add_argument(
        "--word_dict",
        type=str,
        required=False,
        help=("path of word dictionary (default: None)."
              "if this parameter is not set, paddle.dataset.imdb will be used."
              "if this parameter is set, but the file does not exist, "
              "word dictionay will be built from "
              "the training data automatically."),
        default=None)
    parser.add_argument(
        "--label_dict",
        type=str,
        required=False,
        help=("path of label dictionay (default: None)."
              "if this parameter is not set, paddle.dataset.imdb will be used."
              "if this parameter is set, but the file does not exist, "
              "word dictionay will be built from "
              "the training data automatically."),
        default=None)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="the number of training examples in one forward/backward pass")
    parser.add_argument(
        "--num_passes", type=int, default=10, help="number of passes to train")
    parser.add_argument(
        "--model_save_dir",
        type=str,
        required=False,
        help=("path to save the trained models."),
        default="models")

    return parser.parse_args()


def build_dict(data_dir,
               save_path,
               use_col=0,
               cutoff_fre=0,
               insert_extra_words=None):
    """Build a word dictionary from the files under data_dir and save it
    to save_path, one "word\\tfrequency" pair per line, in descending
    order of frequency.

    Words occurring fewer than cutoff_fre times are dropped. Words in
    insert_extra_words are always written first, with a frequency of -1.
    """
    # Avoid a mutable default argument.
    if insert_extra_words is None:
        insert_extra_words = []
    values = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line_splits = line.strip().split("\t")
                # Skip lines that do not contain the requested column.
                if len(line_splits) <= use_col:
                    continue
                for w in line_splits[use_col].split():
                    values[w] += 1

    with open(save_path, "w") as f:
        for w in insert_extra_words:
            f.write("%s\t-1\n" % (w))

        # Sorting is in descending frequency, so stop at the first word
        # that falls below the cutoff.
        for w, count in sorted(
                values.items(), key=lambda x: x[1], reverse=True):
            if count < cutoff_fre:
                break
            f.write("%s\t%d\n" % (w, count))


def load_dict(dict_path):
    """Load a dictionary file and return a word -> index mapping."""
    with open(dict_path, "r") as f:
        return dict((line.strip().split("\t")[0], idx)
                    for idx, line in enumerate(f))


def load_reverse_dict(dict_path):
    """Load a dictionary file and return an index -> word mapping."""
    with open(dict_path, "r") as f:
        return dict((idx, line.strip().split("\t")[0])
                    for idx, line in enumerate(f))