diff --git a/fluid/PaddleNLP/text_classification/async_executor/README.md b/fluid/PaddleNLP/text_classification/async_executor/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c41e3e238704be735bf80db8012b06b14623be99
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/README.md
@@ -0,0 +1,131 @@
# Text Classification

Below is a brief directory layout of this example:

```text
.
|-- README.md          # this file
|-- data_generator     # IMDB dataset generation tools
|   |-- IMDB.py            # extends data_generator.py with the IMDB-specific processing logic
|   |-- build_raw_data.py  # IMDB preprocessing; its output is read by splitfile.py. Format: word word ... | label
|   |-- data_generator.py  # data generation framework used together with AsyncExecutor
|   `-- splitfile.py       # splits the files produced by build_raw_data.py; its output is read by IMDB.py
|-- data_generator.sh  # entry point of the IMDB dataset generation tools
|-- data_reader.py     # data reading utility used by the prediction script
|-- infer.py           # prediction script
`-- train.py           # training script
```

## Introduction

This directory contains scripts that train a text classification task with fluid.AsyncExecutor. The network definitions are reused from nets.py in the parent directory.

## Training

1. Run `sh data_generator.sh` to download the IMDB dataset and convert it into training data that AsyncExecutor can read.
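   The generated samples are plain-text lines in the MultiSlotDataFeed format produced by data_generator/data_generator.py: for each slot, the number of feature ids followed by the ids themselves. A hypothetical three-word review with label 1 would be stored as (the ids are illustrative):

   ```text
   3 1926 8 17 1 1
   ```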
2. Run `python train.py bow` to start training.
   ```python
   python train.py bow # bow selects the network structure; replace it with cnn, lstm or gru if desired
   ```

3. (Optional) To use a custom network structure, add it to [nets.py](../nets.py) and set the corresponding arguments of the train function in [train.py](./train.py):
   ```python
   def train(network,                # network configuration from nets.py
             dict_dim,               # size of the vocabulary
             lr,                     # learning rate
             save_dirname,           # path for saving the model
             training_data_dirname,  # directory holding the training data
             pass_num,               # number of training passes
             thread_num,             # number of AsyncExecutor threads
             batch_size):            # number of samples per batch
   ```

## Sample training output

```text
pass_id: 0 pass_time_cost 4.723438
pass_id: 1 pass_time_cost 3.867186
pass_id: 2 pass_time_cost 4.490111
pass_id: 3 pass_time_cost 4.573296
pass_id: 4 pass_time_cost 4.180547
pass_id: 5 pass_time_cost 4.214476
pass_id: 6 pass_time_cost 4.520387
pass_id: 7 pass_time_cost 4.149485
pass_id: 8 pass_time_cost 3.821354
pass_id: 9 pass_time_cost 5.136178
pass_id: 10 pass_time_cost 4.137318
pass_id: 11 pass_time_cost 3.943429
pass_id: 12 pass_time_cost 3.766478
pass_id: 13 pass_time_cost 4.235983
pass_id: 14 pass_time_cost 4.796462
pass_id: 15 pass_time_cost 4.668116
pass_id: 16 pass_time_cost 4.373798
pass_id: 17 pass_time_cost 4.298131
pass_id: 18 pass_time_cost 4.260021
pass_id: 19 pass_time_cost 4.244411
pass_id: 20 pass_time_cost 3.705138
pass_id: 21 pass_time_cost 3.728070
pass_id: 22 pass_time_cost 3.817919
pass_id: 23 pass_time_cost 4.698598
pass_id: 24 pass_time_cost 4.859262
pass_id: 25 pass_time_cost 5.725732
pass_id: 26 pass_time_cost 5.102599
pass_id: 27 pass_time_cost 3.876582
pass_id: 28 pass_time_cost 4.762538
pass_id: 29 pass_time_cost 3.797759
```

Unlike fluid.Executor, AsyncExecutor does not print the accuracy at the end of each pass. To monitor the training process, set the debug argument of fluid.AsyncExecutor.run() to True; the fetch variables passed in the argument list are then printed at the end of each pass:

```
async_executor.run(
    main_program,
    dataset,
    filelist,
    thread_num,
    [acc],
    debug=True)
```

## Prediction

1. Run `python infer.py bow_model` to start prediction.
   ```python
   python infer.py bow_model # bow_model selects the model directory to load
   ```
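   infer.py also reads an optional second command-line argument that selects the test data directory (it defaults to `test_data`):
   ```python
   python infer.py bow_model test_data # the optional second argument points at the test data directory
   ```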
## Sample prediction output

```text
model_path: bow_model/epoch0.model, avg_acc: 0.882600
model_path: bow_model/epoch1.model, avg_acc: 0.887920
model_path: bow_model/epoch2.model, avg_acc: 0.886920
model_path: bow_model/epoch3.model, avg_acc: 0.884720
model_path: bow_model/epoch4.model, avg_acc: 0.879760
model_path: bow_model/epoch5.model, avg_acc: 0.876920
model_path: bow_model/epoch6.model, avg_acc: 0.874160
model_path: bow_model/epoch7.model, avg_acc: 0.872000
model_path: bow_model/epoch8.model, avg_acc: 0.870360
model_path: bow_model/epoch9.model, avg_acc: 0.868480
model_path: bow_model/epoch10.model, avg_acc: 0.867240
model_path: bow_model/epoch11.model, avg_acc: 0.866200
model_path: bow_model/epoch12.model, avg_acc: 0.865560
model_path: bow_model/epoch13.model, avg_acc: 0.865160
model_path: bow_model/epoch14.model, avg_acc: 0.864480
model_path: bow_model/epoch15.model, avg_acc: 0.864240
model_path: bow_model/epoch16.model, avg_acc: 0.863800
model_path: bow_model/epoch17.model, avg_acc: 0.863520
model_path: bow_model/epoch18.model, avg_acc: 0.862760
model_path: bow_model/epoch19.model, avg_acc: 0.862680
model_path: bow_model/epoch20.model, avg_acc: 0.862240
model_path: bow_model/epoch21.model, avg_acc: 0.862280
model_path: bow_model/epoch22.model, avg_acc: 0.862080
model_path: bow_model/epoch23.model, avg_acc: 0.861560
model_path: bow_model/epoch24.model, avg_acc: 0.861280
model_path: bow_model/epoch25.model, avg_acc: 0.861160
model_path: bow_model/epoch26.model, avg_acc: 0.861080
model_path: bow_model/epoch27.model, avg_acc: 0.860920
model_path: bow_model/epoch28.model, avg_acc: 0.860800
model_path: bow_model/epoch29.model, avg_acc: 0.860760
```

Note: the steady decline in accuracy across epochs is caused by overfitting and can be ignored.
diff --git a/fluid/PaddleNLP/text_classification/async_executor/data_generator.sh b/fluid/PaddleNLP/text_classification/async_executor/data_generator.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ed58befa748a2f068c21fd8394df7107843e724c
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/data_generator.sh
@@ -0,0 +1,29 @@
#!/bin/bash

pushd .
cd ./data_generator

# download the IMDB dataset unless it is already present
if [ ! -f aclImdb_v1.tar.gz ]; then
    wget "http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz"
fi
tar zxvf aclImdb_v1.tar.gz

# build_raw_data.py emits one sample per line ("word word ... | label");
# splitfile.py shards its stdout into 12 part files
mkdir train_data
python build_raw_data.py train | python splitfile.py 12 train_data

mkdir test_data
python build_raw_data.py test | python splitfile.py 12 test_data

# convert the shards into the MultiSlotDataFeed format under ./output_dataset
python IMDB.py train_data
python IMDB.py test_data

mv ./output_dataset/train_data ../
mv ./output_dataset/test_data ../
cp aclImdb/imdb.vocab ../

rm -rf ./output_dataset
rm -rf train_data
rm -rf test_data
rm -rf aclImdb
popd
diff --git a/fluid/PaddleNLP/text_classification/async_executor/data_generator/IMDB.py b/fluid/PaddleNLP/text_classification/async_executor/data_generator/IMDB.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace13d98b9ecf2eb3ee5abd13a0f9a26f739c84d
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/data_generator/IMDB.py
@@ -0,0 +1,46 @@
import re
import os, sys
sys.path.append(os.path.abspath(os.path.join('..')))
from data_generator import MultiSlotDataGenerator


class IMDbDataGenerator(MultiSlotDataGenerator):
    def load_resource(self, dictfile):
        self._vocab = {}
        wid = 0
        with open(dictfile) as f:
            for line in f:
                self._vocab[line.strip()] = wid
                wid += 1
        self._unk_id = len(self._vocab)
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')

    def process(self, line):
        # everything before the last '|' is the review text; HTML line
        # breaks ("<br />") are replaced with spaces before tokenization
        send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
", + " ").strip() + label = [int(line.split('|')[-1])] + + words = [x for x in self._pattern.split(send) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + + return ("words", feas), ("label", label) + + +imdb = IMDbDataGenerator() +imdb.load_resource("aclImdb/imdb.vocab") + +# data from files +file_names = os.listdir(sys.argv[1]) +filelist = [] +for i in range(0, len(file_names)): + filelist.append(os.path.join(sys.argv[1], file_names[i])) + +line_limit = 2500 +process_num = 24 +imdb.run_from_files( + filelist=filelist, + line_limit=line_limit, + process_num=process_num, + output_dir=('output_dataset/%s' % (sys.argv[1]))) diff --git a/fluid/PaddleNLP/text_classification/async_executor/data_generator/data_generator.py b/fluid/PaddleNLP/text_classification/async_executor/data_generator/data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1236a9a9dc2eb9d5a78a3096007b0a1b683181 --- /dev/null +++ b/fluid/PaddleNLP/text_classification/async_executor/data_generator/data_generator.py @@ -0,0 +1,494 @@ +import os +import sys +import multiprocessing +__all__ = ['MultiSlotDataGenerator'] + + +class DataGenerator(object): + def __init__(self): + self._proto_info = None + + def _set_filelist(self, filelist): + if not isinstance(filelist, list) and not isinstance(filelist, tuple): + raise ValueError("filelist%s must be in list or tuple type" % + type(filelist)) + if not filelist: + raise ValueError("filelist can not be empty") + self._filelist = filelist + + def _set_process_num(self, process_num): + if not isinstance(process_num, int): + raise ValueError("process_num%s must be in int type" % + type(process_num)) + if process_num < 1: + raise ValueError("process_num can not less than 1") + self._process_num = process_num + + def _set_line_limit(self, line_limit): + if not isinstance(line_limit, int): + raise ValueError("line_limit%s must be in int type" % + type(line_limit)) + if line_limit < 1: + raise ValueError("line_limit can not less than 1") + self._line_limit = line_limit + + def _set_output_dir(self, output_dir): + if not isinstance(output_dir, str): + raise ValueError("output_dir%s must be in str type" % + type(output_dir)) + if not output_dir: + raise ValueError("output_dir can not be empty") + self._output_dir = output_dir + + def _set_output_prefix(self, output_prefix): + if not isinstance(output_prefix, str): + raise ValueError("output_prefix%s must be in str type" % + type(output_prefix)) + self._output_prefix = output_prefix + + def _set_output_fill_digit(self, output_fill_digit): + if not isinstance(output_fill_digit, int): + raise ValueError("output_fill_digit%s must be in int type" % + type(output_fill_digit)) + if output_fill_digit < 1: + raise ValueError("output_fill_digit can not less than 1") + self._output_fill_digit = output_fill_digit + + def _set_proto_filename(self, proto_filename): + if not isinstance(proto_filename, str): + raise ValueError("proto_filename%s must be in str type" % + type(proto_filename)) + if not proto_filename: + raise ValueError("proto_filename can not be empty") + self._proto_filename = proto_filename + + def _print_info(self): + ''' + Print the configuration information + (Called only in the run_from_stdin function). 
        '''
        sys.stderr.write("=" * 16 + " config " + "=" * 16 + "\n")
        sys.stderr.write(" filelist size: %d\n" % len(self._filelist))
        sys.stderr.write(" process num: %d\n" % self._process_num)
        sys.stderr.write(" line limit: %d\n" % self._line_limit)
        sys.stderr.write(" output dir: %s\n" % self._output_dir)
        sys.stderr.write(" output prefix: %s\n" % self._output_prefix)
        sys.stderr.write(" output fill digit: %d\n" % self._output_fill_digit)
        sys.stderr.write(" proto filename: %s\n" % self._proto_filename)
        sys.stderr.write("==== This may take a few minutes... ====\n")

    def _get_output_filename(self, output_index, lock=None):
        '''
        This function is used to get the name of the next output file and
        to update output_index.
        Args:
            output_index(manager.Value(i)): the index of the output file.
            lock(manager.Lock): the lock used for process safety.
        Return:
            Return the name(string) of the output file.
        '''
        if lock is not None: lock.acquire()
        file_index = output_index.value
        output_index.value += 1
        if lock is not None: lock.release()
        filename = os.path.join(self._output_dir, self._output_prefix) \
                + str(file_index).zfill(self._output_fill_digit)
        sys.stderr.write("[%d] write data to file: %s\n" %
                         (os.getpid(), filename))
        return filename

    def run_from_stdin(self,
                       is_local=True,
                       hadoop_host=None,
                       hadoop_ugi=None,
                       proto_path=None,
                       proto_filename="data_feed.proto"):
        '''
        This function reads data rows from stdin, parses each row with the
        user-defined process function, and further serializes the return
        value of process with the _gen_str function. The parsed data is
        written to stdout and the corresponding protofile is generated.
        If is_local is set to False, the protofile is uploaded to Hadoop.
        Args:
            is_local(bool): Whether to run locally. If it is False, the
                            protofile will be uploaded to Hadoop. The
                            default value is True.
            hadoop_host(str): The host name of the Hadoop namenode. It
                              should be in this format:
                              "hdfs://${HOST}:${PORT}".
            hadoop_ugi(str): The ugi of the Hadoop cluster. It should be in
                             this format: "${USERNAME},${PASSWORD}".
            proto_path(str): The Hadoop path to upload the protofile to.
            proto_filename(str): The name of the protofile. The default
                                 value is "data_feed.proto". It is not
                                 recommended to modify it.
        '''
        if is_local:
            print \
'''\033[1;34m=======================================================
 Note that the Python version on Hadoop may be inconsistent
 with the local version. Please check that the Python version
 on Hadoop is >= 2.7.
=======================================================\033[0m'''
        else:
            if hadoop_ugi is None or \
               hadoop_host is None or \
               proto_path is None:
                raise ValueError(
                    "please set hadoop_ugi, hadoop_host and proto_path")
        self._set_proto_filename(proto_filename)
        for line in sys.stdin:
            user_parsed_line = self.process(line)
            sys.stdout.write(self._gen_str(user_parsed_line))
        if self._proto_info is not None:
            # some tasks may receive no input, leaving _proto_info as None;
            # only write (and optionally upload) the protofile otherwise
            with open(self._proto_filename, "w") as f:
                f.write(self._get_proto_desc(self._proto_info))
            if not is_local:
                cmd = "$HADOOP_HOME/bin/hadoop fs" \
                      + " -Dhadoop.job.ugi=" + hadoop_ugi \
                      + " -Dfs.default.name=" + hadoop_host \
                      + " -put " + self._proto_filename + " " + proto_path
                os.system(cmd)

    def run_from_files(self,
                       filelist,
                       line_limit,
                       process_num=1,
                       output_dir="./output_dataset",
                       output_prefix="part-",
                       output_fill_digit=8,
                       proto_filename="data_feed.proto"):
        '''
        This function runs process_num processes to process the files in
        filelist. It creates the output folder(output_dir) in the current
        directory and writes the processed data into it: each output file
        holds at most line_limit samples, its name starts with output_prefix
        and ends with a zero-padded index of output_fill_digit digits. The
        proto_info is generated at the same time, and the proto file is
        named proto_filename.
        Args:
            filelist(list or tuple): Files that need to be processed.
            line_limit(int): Maximum number of samples stored per file.
            process_num(int): Number of processes running simultaneously.
            output_dir(str): The name of the folder where the output
                             data files are stored.
            output_prefix(str): The prefix of the output data files.
            output_fill_digit(int): The number of digits in the numeric
                                    suffix of the output data files.
            proto_filename(str): The name of the proto file.
        '''
        self._set_filelist(filelist)
        self._set_line_limit(line_limit)
        self._set_process_num(min(process_num, len(filelist)))
        self._set_output_dir(output_dir)
        self._set_output_prefix(output_prefix)
        self._set_output_fill_digit(output_fill_digit)
        self._set_proto_filename(proto_filename)
        self._print_info()

        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)
        elif not os.path.isdir(self._output_dir):
            raise ValueError("%s is not a directory" % self._output_dir)

        processes = multiprocessing.Pool()
        manager = multiprocessing.Manager()
        output_index = manager.Value('i', 0)
        file_queue = manager.Queue()
        lock = manager.Lock()
        remaining_queue = manager.Queue()
        for file_name in self._filelist:
            file_queue.put(file_name)
        info_result = []
        for i in range(self._process_num):
            info_result.append(processes.apply_async(subprocess_wrapper, \
                (self, file_queue, remaining_queue, output_index, lock, )))
        processes.close()
        processes.join()

        infos = [
            result.get() for result in info_result if result.get() is not None
        ]
        proto_info = self._combine_infos(infos)
        with open(os.path.join(self._output_dir, self._proto_filename),
                  "w") as f:
            f.write(self._get_proto_desc(proto_info))

        while not remaining_queue.empty():
            with open(self._get_output_filename(output_index), "w") as f:
                for i in range(min(self._line_limit, remaining_queue.qsize())):
                    f.write(remaining_queue.get(False))

    def _subprocess(self, file_queue, remaining_queue, output_index, lock):
        '''
        This function will be called by multiple processes.
        It is used to
        continuously fetch files from file_queue and to process them line
        by line with the user-defined process() function and the _gen_str()
        function defined by the concrete subclass. The processed data is
        written to output files of self._line_limit lines each. When the
        files in file_queue have been consumed but the last output file is
        not full, the leftover lines (fewer than self._line_limit) are
        stored in remaining_queue.
        Args:
            file_queue(manager.Queue): The queue that contains all the file
                                       names to be processed.
            remaining_queue(manager.Queue): The queue that collects the
                                            leftover lines (fewer than
                                            self._line_limit).
            output_index(manager.Value(i)): The index(suffix) of the
                                            output file.
            lock(manager.Lock): The lock used for process safety.
        Returns:
            Return a proto_info which can be translated into a proto string.
        '''
        buffer = []
        while not file_queue.empty():
            try:
                filename = file_queue.get(False)
            except:  # file_queue is empty
                break
            with open(filename, 'r') as f:
                for line in f:
                    buffer.append(self._gen_str(self.process(line)))
                    if len(buffer) == self._line_limit:
                        with open(
                                self._get_output_filename(output_index, lock),
                                "w") as wf:
                            for x in buffer:
                                wf.write(x)
                        buffer = []
        if buffer:
            for x in buffer:
                remaining_queue.put(x)
        return self._proto_info

    def _gen_str(self, line):
        '''
        Further process the output of the user-defined process() function
        into data that can be read directly by the datafeed, and update the
        proto_info information.
        Args:
            line(str): the output of the process() function rewritten by user.
        Returns:
            Return a string that can be read directly by the datafeed.
        '''
        raise NotImplementedError(
            "please use MultiSlotDataGenerator or PairWiseDataGenerator")

    def _combine_infos(self, infos):
        '''
        This function is used to merge the proto_info from different
        processes. In general, the proto_info of each process is consistent.
        Args:
            infos(list): the list of proto_infos from different processes.
        Returns:
            Return a unified proto_info.
        '''
        raise NotImplementedError(
            "please use MultiSlotDataGenerator or PairWiseDataGenerator")

    def _get_proto_desc(self, proto_info):
        '''
        This function generates the string of the proto file(which can be
        written to a file directly) from the proto_info.
        Args:
            proto_info: The proto information used to generate the proto
                        string. The type of the variable is determined
                        by the subclass. In MultiSlotDataGenerator, the
                        proto_info variable is a list of tuples.
        Returns:
            Returns a string of the proto file.
        '''
        raise NotImplementedError(
            "please use MultiSlotDataGenerator or PairWiseDataGenerator")

    def process(self, line):
        '''
        This function needs to be overridden by the user to process the
        original data row into a list or tuple.
        Args:
            line(str): the original data row
        Returns:
            Returns the data processed by the user.
              The data format is list or tuple:
            [(name, [feasign, ...]), ...]
              or ((name, [feasign, ...]), ...)

            For example:
            [("words", [1926, 8, 17]), ("label", [1])]
              or (("words", [1926, 8, 17]), ("label", [1]))
        Note:
            Each feasign must be an int or a float. Once a float element
            appears among a slot's feasigns, that slot is treated as float.
        '''
        raise NotImplementedError(
            "please rewrite this function to return a list or tuple: " +
            "[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
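
# For reference, a concrete subclass (see IMDB.py) returns from process()
# something like:
#     (("words", [1926, 8, 17]), ("label", [1]))
# which MultiSlotDataGenerator._gen_str() serializes as the line
# "3 1926 8 17 1 1" (each slot's id count followed by its ids).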

def subprocess_wrapper(instance, file_queue, remaining_queue, output_index,
                       lock):
    '''
    Bound methods can not be used directly as a multiprocessing task, so
    this module-level wrapper runs DataGenerator._subprocess in a worker.
    '''
    return instance._subprocess(file_queue, remaining_queue, output_index,
                                lock)


class MultiSlotDataGenerator(DataGenerator):
    def _combine_infos(self, infos):
        '''
        This function is used to merge the proto_info from different
        processes. In general, the proto_info of each process is consistent.
        The type of infos is list, and its elements are tuples in the
        format (name, type).
        Args:
            infos(list): the list of proto_infos from different processes.
        Returns:
            Return a unified proto_info.
        Note:
            This function is only called by the run_from_files function, so
            when using the run_from_stdin function(usually used for Hadoop),
            the output of the user-defined process function must not give
            the same field both float and int values.
        '''
        proto_info = infos[0]
        for info in infos:
            for index, slot in enumerate(info):
                name, slot_type = slot
                if name != proto_info[index][0]:
                    raise ValueError(
                        "combine infos error: slot names from different processes do not match"
                    )
                if slot_type == "float" and proto_info[index][1] == "uint64":
                    proto_info[index] = (name, slot_type)
        return proto_info

    def _get_proto_desc(self, proto_info):
        '''
        Generate the string of the proto file from the proto_info.

        The proto_info is a list of tuples:
        >>> [(Name, Type), ...]

        The string of the proto file will be in this format:
        >>> name: "MultiSlotDataFeed"
        >>> batch_size: 32
        >>> multi_slot_desc {
        >>>     slots {
        >>>         name: Name
        >>>         type: Type
        >>>         is_dense: false
        >>>         is_used: false
        >>>     }
        >>> }
        Args:
            proto_info(list): The proto information used to generate the
                              proto string.
        Returns:
            Returns a string of the proto file.
        '''
        proto_str = "name: \"MultiSlotDataFeed\"\n" \
                    + "batch_size: 32\nmulti_slot_desc {\n"
        for elem in proto_info:
            proto_str += "  slots {\n" \
                         + "    name: \"%s\"\n" % elem[0]\
                         + "    type: \"%s\"\n" % elem[1]\
                         + "    is_dense: false\n" \
                         + "    is_used: false\n" \
                         + "  }\n"
        proto_str += "}"
        return proto_str

    def _gen_str(self, line):
        '''
        Further process the output of the user-defined process() function
        into data that can be read directly by the MultiSlotDataFeed, and
        update the proto_info information.
        The input line will be in this format:
        >>> [(name, [feasign, ...]), ...]
        >>> or ((name, [feasign, ...]), ...)
        The output will be in this format:
        >>> [ids_num id1 id2 ...] ...
        The proto_info will be in this format:
        >>> [(name, type), ...]

        For example, if the input is:
        >>> [("words", [1926, 8, 17]), ("label", [1])]
        >>> or (("words", [1926, 8, 17]), ("label", [1]))
        the output will be:
        >>> 3 1926 8 17 1 1
        and the proto_info will be:
        >>> [("words", "uint64"), ("label", "uint64")]
        Args:
            line(str): the output of the process() function rewritten by user.
        Returns:
            Return a string that can be read directly by the MultiSlotDataFeed.
        '''
        if not isinstance(line, (list, tuple)):
            raise ValueError(
                "the output of process() must be a list or tuple")
        output = ""

        if self._proto_info is None:
            self._proto_info = []
            for item in line:
                name, elements = item
                if not isinstance(name, str):
                    raise ValueError("name(%s) must be a str" % type(name))
                if not isinstance(elements, list):
                    raise ValueError("elements(%s) must be a list" %
                                     type(elements))
                if not elements:
                    raise ValueError(
                        "the elements of each field can not be empty; pad them in process()."
                    )
                self._proto_info.append((name, "uint64"))
                if output:
                    output += " "
                output += str(len(elements))
                for elem in elements:
                    if isinstance(elem, float):
                        self._proto_info[-1] = (name, "float")
                    elif not isinstance(elem, int) and not isinstance(elem,
                                                                      long):
                        raise ValueError(
                            "the type of element(%s) must be int or float" %
                            type(elem))
                    output += " " + str(elem)
        else:
            if len(line) != len(self._proto_info):
                raise ValueError(
                    "the field sets of two given lines are inconsistent.")
            for index, item in enumerate(line):
                name, elements = item
                if not isinstance(name, str):
                    raise ValueError("name(%s) must be a str" % type(name))
                if not isinstance(elements, list):
                    raise ValueError("elements(%s) must be a list" %
                                     type(elements))
                if not elements:
                    raise ValueError(
                        "the elements of each field can not be empty; pad them in process()."
                    )
                if name != self._proto_info[index][0]:
                    raise ValueError(
                        "the field names of two given lines do not match: expected<%s>, got<%s>."
                        % (self._proto_info[index][0], name))
                if output:
                    output += " "
                output += str(len(elements))
                for elem in elements:
                    if self._proto_info[index][1] != "float":
                        if isinstance(elem, float):
                            self._proto_info[index] = (name, "float")
                        elif not isinstance(elem, int) and not isinstance(
                                elem, long):
                            raise ValueError(
                                "the type of element(%s) must be int or float"
                                % type(elem))
                    output += " " + str(elem)
        return output + "\n"
diff --git a/fluid/PaddleNLP/text_classification/async_executor/data_generator/splitfile.py b/fluid/PaddleNLP/text_classification/async_executor/data_generator/splitfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..273d5cea889e5b08086809883aff6ea934ee4d43
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/data_generator/splitfile.py
@@ -0,0 +1,16 @@
"""
Split the lines read from stdin evenly into multiple part files.
"""
import sys
import os
block = int(sys.argv[1])  # number of part files
datadir = sys.argv[2]  # output directory
file_list = []
for i in range(block):
    file_list.append(open(datadir + "/part-" + str(i), "w"))
id_ = 0
for line in sys.stdin:
    file_list[id_ % block].write(line)
    id_ += 1
for f in file_list:
    f.close()
diff --git a/fluid/PaddleNLP/text_classification/async_executor/data_reader.py b/fluid/PaddleNLP/text_classification/async_executor/data_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8bc7780ba7ad6c8a5ce7f2af795b0098baf60ac
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/data_reader.py
@@ -0,0 +1,36 @@
import sys
import os
import paddle


def parse_fields(fields):
    # fields: [word count, word ids ..., label slot size, label]
    words_width = int(fields[0])
    words = [int(x) for x in fields[1:1 + words_width]]
    label = int(fields[-1])

    return words, label


def imdb_data_feed_reader(data_dir, batch_size, buf_size):
    """
    Data feed reader for the IMDB dataset.
    The dataset has been converted from its original format into a format
    suitable for AsyncExecutor; see data_feed.proto for the data format.
    """

    def reader():
        for file_name in os.listdir(data_dir):
            if file_name.endswith('.proto'):
                continue

            with open(os.path.join(data_dir, file_name), 'r') as f:
                for line in f:
                    # each line is in MultiSlotDataFeed text format, e.g.
                    # "3 1926 8 17 1 1": the word count, the word ids,
                    # then the label slot size (1) and the label itself
                    fields = line.strip().split(' ')
                    words, label = parse_fields(fields)
                    yield words, label

    test_reader = paddle.batch(
        paddle.reader.shuffle(
            reader, buf_size=buf_size), batch_size=batch_size)
    return test_reader
diff --git a/fluid/PaddleNLP/text_classification/async_executor/infer.py b/fluid/PaddleNLP/text_classification/async_executor/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d6553f2ec02eff2f7eb0736e4daf76daefdc63
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/infer.py
@@ -0,0 +1,65 @@
import os
import sys

import paddle
import paddle.fluid as fluid

import data_reader


def infer(test_reader, use_cuda, model_path=None):
    """
    inference function
    """
    if model_path is None:
        print(str(model_path) + " cannot be found")
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(model_path, exe)

        total_acc = 0.0
        total_count = 0
        for data in test_reader():
            acc = exe.run(inference_program,
                          feed=utils.data2tensor(data, place),
                          fetch_list=fetch_targets,
                          return_numpy=True)
            total_acc += acc[0] * len(data)
            total_count += len(data)

        avg_acc = total_acc / total_count
        print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))


if __name__ == "__main__":
    if __package__ is None:
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    import utils

    batch_size = 128
    model_path = sys.argv[1]
    test_data_dirname = 'test_data'

    if len(sys.argv) == 3:
        test_data_dirname = sys.argv[2]

    test_reader = data_reader.imdb_data_feed_reader(
        test_data_dirname, batch_size, buf_size=500000)

    models = os.listdir(model_path)
    for i in range(0, len(models)):
        epoch_path = "epoch" + str(i) + ".model"
        epoch_path = os.path.join(model_path, epoch_path)
        infer(test_reader, use_cuda=False, model_path=epoch_path)
diff --git a/fluid/PaddleNLP/text_classification/async_executor/train.py b/fluid/PaddleNLP/text_classification/async_executor/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..40a3a8e5f903ddc23400e7a4063e022685110bdf
--- /dev/null
+++ b/fluid/PaddleNLP/text_classification/async_executor/train.py
@@ -0,0 +1,98 @@
import os
import sys
import time
import multiprocessing

import paddle
import paddle.fluid as fluid


def train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
          thread_num, batch_size):
    file_names = os.listdir(training_data_dirname)
    filelist = []
    for i in range(0, len(file_names)):
        if file_names[i] == 'data_feed.proto':
            continue
        filelist.append(os.path.join(training_data_dirname, file_names[i]))

    dataset = fluid.DataFeedDesc(
        os.path.join(training_data_dirname, 'data_feed.proto'))
    dataset.set_batch_size(
        batch_size)  # the datafeed should be assigned a batch size
    dataset.set_use_slots(['words', 'label'])

    # the enabled slots must match the fluid.layers.data variables below;
    # AsyncExecutor feeds each slot into the variable of the same name
    data = fluid.layers.data(
        name="words", shape=[1],
        dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    avg_cost, acc, prediction = network(data, label, dict_dim)
    optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    opt_ops, weight_and_grad = optimizer.minimize(avg_cost)

    startup_program = fluid.default_startup_program()
    main_program = fluid.default_main_program()

    place = fluid.CPUPlace()
    executor = fluid.Executor(place)
    executor.run(startup_program)

    async_executor = fluid.AsyncExecutor(place)
    for i in range(pass_num):
        pass_start = time.time()
        async_executor.run(main_program,
                           dataset,
                           filelist,
                           thread_num, [acc],
                           debug=False)
        print('pass_id: %u pass_time_cost %f' % (i, time.time() - pass_start))
        fluid.io.save_inference_model('%s/epoch%d.model' % (save_dirname, i),
                                      [data.name, label.name], [acc], executor)


if __name__ == "__main__":
    if __package__ is None:
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

    from nets import bow_net, cnn_net, lstm_net, gru_net
    from utils import load_vocab

    batch_size = 4
    lr = 0.002
    pass_num = 30
    save_dirname = ""
    thread_num = multiprocessing.cpu_count()

    if sys.argv[1] == "bow":
        network = bow_net
        batch_size = 128
        save_dirname = "bow_model"
    elif sys.argv[1] == "cnn":
        network = cnn_net
        lr = 0.01
        save_dirname = "cnn_model"
    elif sys.argv[1] == "lstm":
        network = lstm_net
        lr = 0.05
        save_dirname = "lstm_model"
    elif sys.argv[1] == "gru":
        network = gru_net
        batch_size = 128
        lr = 0.05
        save_dirname = "gru_model"
    else:
        sys.stderr.write("network must be one of bow, cnn, lstm or gru\n")
        sys.exit(1)

    training_data_dirname = 'train_data/'
    if len(sys.argv) >= 3:
        training_data_dirname = sys.argv[2]

    if len(sys.argv) >= 4:
        # cap the number of worker threads at the user-provided value
        thread_num = min(thread_num, int(sys.argv[3]))

    vocab = load_vocab('imdb.vocab')
    dict_dim = len(vocab)

    train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
          thread_num, batch_size)
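
# Typical invocations (see the README; the extra arguments are optional):
#   python train.py bow                # train the bow network on ./train_data
#   python train.py bow train_data 8   # custom data dir and thread cap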