utils.py 2.0 KB
Newer Older
P
peterzhang2029 已提交
1
import os
P
peterzhang2029 已提交
2
from collections import defaultdict
P
peterzhang2029 已提交
3 4 5 6 7 8 9


def get_file_list(image_file_list):
    """
    Generate the file list for training and testing data.

    Every line of the list file is split at its first comma into a file
    name and a label field. The file name is joined onto the directory
    that contains the list file. The label is the label field with its
    first two characters and its last character removed, then stripped
    of surrounding whitespace.

    Blank lines, lines without a comma, and lines whose label is empty
    are skipped.

    :param image_file_list: The path of the file which contains
                            path list of image files.
    :type image_file_list: str
    :return: A list of ``(path, label)`` tuples, in file order.
    :rtype: list
    """
    dirname = os.path.dirname(image_file_list)
    path_list = []
    with open(image_file_list) as f:
        for line in f:
            line_split = line.strip().split(',', 1)
            if len(line_split) < 2:
                # Blank line or no label field: skip it, in the same way
                # that empty labels are skipped below, instead of failing
                # with an IndexError.
                continue
            filename = line_split[0].strip()
            path = os.path.join(dirname, filename)
            # NOTE(review): the [2:-1] slice presumably drops a leading
            # space plus an opening quote, and a closing quote -- confirm
            # against the actual list-file format.
            label = line_split[1][2:-1].strip()
            if label:
                path_list.append((path, label))

    return path_list
P
peterzhang2029 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69


def build_label_dict(file_list, save_path):
    """
    Build label dictionary from training data.

    Counts every character that occurs in the labels, adds an ``<unk>``
    entry with a count of 0, and writes one ``<character>\\t<count>`` line
    per entry to ``save_path``, ordered by descending count.

    :param file_list: The list which contains the labels
                      of training data, as (path, label) pairs.
    :type file_list: list
    :param save_path: The path where the label dictionary will be saved.
    :type save_path: str
    """
    values = defaultdict(int)
    for _, label in file_list:
        for c in label:
            if c:
                values[c] += 1

    values['<unk>'] = 0
    with open(save_path, "w") as f:
        # dict.items() exists on both Python 2 and Python 3, whereas
        # iteritems() is Python 2 only and raises AttributeError on 3.
        for v, count in sorted(
                values.items(), key=lambda x: x[1], reverse=True):
            f.write("%s\t%d\n" % (v, count))


def load_dict(dict_path):
    """
    Load label dictionary from the dictionary path.

    Each line of the file is expected to start with a label followed by a
    tab; the label is mapped to the zero-based index of its line. If a
    label occurs on several lines, the last index wins.

    :param dict_path: The path of word dictionary.
    :type dict_path: str
    :return: A mapping from label to line index.
    :rtype: dict
    """
    # The context manager closes the file handle, which a bare open() in
    # the expression would leave to the garbage collector.
    with open(dict_path, "r") as f:
        # Strip only the line terminator: a full strip() would also remove
        # a leading whitespace label (e.g. a space character), so the
        # count column would be read as the label.
        return dict((line.rstrip("\r\n").split("\t")[0], idx)
                    for idx, line in enumerate(f))


def load_reverse_dict(dict_path):
    """
    Load the reversed label dictionary from dictionary path.

    Each line of the file is expected to start with a label followed by a
    tab; the zero-based index of the line is mapped to that label.

    :param dict_path: The path of word dictionary.
    :type dict_path: str
    :return: A mapping from line index to label.
    :rtype: dict
    """
    # The context manager closes the file handle, which a bare open() in
    # the expression would leave to the garbage collector.
    with open(dict_path, "r") as f:
        # Strip only the line terminator: a full strip() would also remove
        # a leading whitespace label (e.g. a space character), so the
        # count column would be read as the label.
        return dict((idx, line.rstrip("\r\n").split("\t")[0])
                    for idx, line in enumerate(f))