utils.py 1.1 KB
Newer Older
O
overlordmax 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
import numpy as np 
import os
import paddle.fluid as fluid
import logging
from collections import defaultdict

# Install a root handler with a timestamped format (logging.basicConfig does
# nothing if the root logger already has handlers) and expose the shared
# "fluid" logger, with its level set to INFO.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

# Ordered field ids; a field's position in this list is the index stored for
# it in all_field_id_dict.
all_field_id = ['101', '109_14', '110_14', '127_14', '150_14', '121', '122', '124', '125', '126', '127', '128', '129',
                '205', '206', '207', '210', '216', '508', '509', '702', '853', '301']
# Maps each field id to its own mutable [flag, index] list, where index is the
# field's position in all_field_id and flag starts as False. Presumably callers
# toggle the flag in place per sample -- TODO confirm against the callers.
# A plain dict is used on purpose: an unknown field id raises KeyError instead
# of silently inserting a default entry whose shape does not match
# [flag, index] (which is what a defaultdict(int) would do).
all_field_id_dict = {field_id: [False, i] for i, field_id in enumerate(all_field_id)}

def get_dataset(inputs, files, batch_size, cpu_num,
                pipe_command="python dataset_generator.py"):
    """Create a fluid dataset and collect the data files it should read.

    Args:
        inputs: variables passed to ``dataset.set_use_var``.
        files: directory whose regular files form the returned file list.
        batch_size: value passed to ``dataset.set_batch_size``.
        cpu_num: thread count; converted with ``int()`` before being passed
            to ``dataset.set_thread``.
        pipe_command: command given to ``dataset.set_pipe_command``. Defaults
            to the previously hard-coded ``"python dataset_generator.py"``,
            whose relative script path depends on the working directory the
            pipe is launched from.

    Returns:
        A ``(dataset, file_list)`` tuple. ``file_list`` is NOT attached to the
        dataset here; presumably the caller does that (e.g. via
        ``dataset.set_filelist``) -- confirm at the call site.

    Raises:
        OSError: if ``files`` cannot be listed (e.g. it does not exist).
    """
    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(inputs)
    dataset.set_pipe_command(pipe_command)
    dataset.set_batch_size(batch_size)
    dataset.set_thread(int(cpu_num))
    # os.listdir returns entries in arbitrary, filesystem-dependent order, so
    # sort to make the file list (and hence data order) reproducible. Skip
    # subdirectories: they cannot be consumed as data files.
    file_list = sorted(
        path
        for path in (os.path.join(files, name) for name in os.listdir(files))
        if os.path.isfile(path)
    )
    logger.info("file list: %s", file_list)
    return dataset, file_list
    
def get_vocab_size(vocab_path):
    """Return the integer on the first line of *vocab_path*, plus one.

    Whitespace around the number is stripped; any later lines are ignored.
    The reason for the +1 is not visible here (presumably the stored value
    is a maximum id and the extra slot covers id 0 -- confirm with the
    vocabulary writer).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the first line is empty or is not an integer.
    """
    with open(vocab_path, "r") as vocab_file:
        first_line = next(iter(vocab_file), "")
    return int(first_line.strip()) + 1