#!/usr/bin/env python # -*- coding: utf-8 -*- import logging import os import argparse from collections import defaultdict logger = logging.getLogger("paddle") logger.setLevel(logging.INFO) def parse_train_cmd(): parser = argparse.ArgumentParser( description="PaddlePaddle text classification demo") parser.add_argument( "--nn_type", type=str, help="define which type of network to use, available: [dnn, cnn]", default="dnn") parser.add_argument( "--train_data_dir", type=str, required=False, help=("path of training dataset (default: None). " "if this parameter is not set, " "paddle.dataset.imdb will be used."), default=None) parser.add_argument( "--test_data_dir", type=str, required=False, help=("path of testing dataset (default: None). " "if this parameter is not set, " "paddle.dataset.imdb will be used."), default=None) parser.add_argument( "--word_dict", type=str, required=False, help=("path of word dictionary (default: None)." "if this parameter is not set, paddle.dataset.imdb will be used." "if this parameter is set, but the file does not exist, " "word dictionay will be built from " "the training data automatically."), default=None) parser.add_argument( "--label_dict", type=str, required=False, help=("path of label dictionay (default: None)." "if this parameter is not set, paddle.dataset.imdb will be used." "if this parameter is set, but the file does not exist, " "word dictionay will be built from " "the training data automatically."), default=None) parser.add_argument( "--batch_size", type=int, default=32, help="the number of training examples in one forward/backward pass") parser.add_argument( "--num_passes", type=int, default=10, help="number of passes to train") parser.add_argument( "--model_save_dir", type=str, required=False, help=("path to save the trained models."), default="models") return parser.parse_args() def build_dict(data_dir, save_path, use_col=0, cutoff_fre=0, insert_extra_words=[]): values = defaultdict(int) for file_name in os.listdir(data_dir): file_path = os.path.join(data_dir, file_name) if not os.path.isfile(file_path): continue with open(file_path, "r") as fdata: for line in fdata: line_splits = line.strip().split("\t") if len(line_splits) < use_col: continue for w in line_splits[use_col].split(): values[w] += 1 with open(save_path, "w") as f: for w in insert_extra_words: f.write("%s\t-1\n" % (w)) for v, count in sorted( values.iteritems(), key=lambda x: x[1], reverse=True): if count < cutoff_fre: break f.write("%s\t%d\n" % (v, count)) def load_dict(dict_path): return dict((line.strip().split("\t")[0], idx) for idx, line in enumerate(open(dict_path, "r").readlines())) def load_reverse_dict(dict_path): return dict((idx, line.strip().split("\t")[0]) for idx, line in enumerate(open(dict_path, "r").readlines()))