# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The file_reader converts raw corpus to input.
"""

import argparse
import io
import random

import numpy as np

from paddleslim.pantheon import Student


def load_kv_dict(dict_path,
                 reverse=False,
                 delimiter="\t",
                 key_func=None,
                 value_func=None):
    """
    Load a key-value dict from file. Each line holds one delimiter-separated
    key-value pair; lines that do not split into exactly two fields are skipped.
    """
    result_dict = {}
    for line in io.open(dict_path, "r", encoding="utf8"):
        terms = line.strip("\n").split(delimiter)
        if len(terms) != 2:
            continue
        if reverse:
            value, key = terms
        else:
            key, value = terms
        if key in result_dict:
            raise KeyError("key duplicated with [%s]" % (key))
        if key_func:
            key = key_func(key)
        if value_func:
            value = value_func(value)
        result_dict[key] = value
    return result_dict


class Dataset(object):
    """Data reader that pairs corpus samples with teacher knowledge."""

    def __init__(self, args, mode="train"):
        # read dicts
        self.word2id_dict = load_kv_dict(
            args.word_dict_path, reverse=True, value_func=int)
        self.id2word_dict = load_kv_dict(args.word_dict_path)
        self.label2id_dict = load_kv_dict(
            args.label_dict_path, reverse=True, value_func=int)
        self.id2label_dict = load_kv_dict(args.label_dict_path)
        self.word_replace_dict = load_kv_dict(args.word_rep_dict_path)

        # connect to the teacher and open a knowledge stream
        self._student = Student()
        self._student.register_teacher(in_address=args.in_address)
        self._student.start()
        self._know_desc = self._student.get_knowledge_desc()
        self._know_data_generator = self._student.get_knowledge_generator(
            batch_size=1, drop_last=False)()
        self._train_shuffle_buf_size = args.traindata_shuffle_buffer

    @property
    def vocab_size(self):
        """vocabulary size"""
        return max(self.word2id_dict.values()) + 1

    @property
    def num_labels(self):
        """number of labels"""
        return max(self.label2id_dict.values()) + 1

    def get_num_examples(self, filename):
        """number of lines in the file"""
        return sum(1 for line in io.open(filename, "r", encoding="utf8"))

    def word_to_ids(self, words):
        """convert words to word indices"""
        word_ids = []
        for word in words:
            word = self.word_replace_dict.get(word, word)
            if word not in self.word2id_dict:
                word = "OOV"
            word_id = self.word2id_dict[word]
            word_ids.append(word_id)
        return word_ids

    def label_to_ids(self, labels):
        """convert labels to label indices"""
        label_ids = []
        for label in labels:
            if label not in self.label2id_dict:
                label = "O"
            label_id = self.label2id_dict[label]
            label_ids.append(label_id)
        return label_ids

    def file_reader(self, filename, max_seq_len=126, mode="train"):
        """
        yield (word_ids, label_ids, teacher_emission) one by one from file,
        or yield (word_ids, ) in `infer` mode
        """

        def wrapper():
            invalid_samples = 0
            fread = io.open(filename, "r", encoding="utf-8")
            if mode == "infer":
                for line in fread:
                    # no labels in infer mode; iterate over the characters
                    words = line.strip()
                    word_ids = self.word_to_ids(words)
                    yield (word_ids[0:max_seq_len], )
            elif mode == "test":
                headline = next(fread)
                headline = headline.strip().split('\t')
                assert len(headline) == 2 and headline[0] == "text_a" \
                    and headline[1] == "label"
                for line in fread:
                    words, labels = line.strip("\n").split("\t")
                    if len(words) < 1:
                        continue
                    word_ids = self.word_to_ids(words.split("\002"))
                    label_ids = self.label_to_ids(labels.split("\002"))
                    assert len(word_ids) == len(label_ids)
                    yield word_ids[0:max_seq_len], label_ids[0:max_seq_len]
            else:  # train
                headline = next(fread)
                headline = headline.strip().split('\t')
                assert len(headline) == 2 and headline[0] == "text_a" \
                    and headline[1] == "label"
                buf = []
                for line in fread:
                    words, labels = line.strip("\n").split("\t")
                    if len(words) < 1:
                        continue
                    word_ids = self.word_to_ids(words.split("\002"))
                    label_ids = self.label_to_ids(labels.split("\002"))
                    # built-in next() works on both Python 2 and 3 generators
                    know_data = next(self._know_data_generator)
                    teacher_crf_decode = know_data["crf_decode"]
                    if len(teacher_crf_decode.shape) == 1:
                        teacher_crf_decode = np.reshape(teacher_crf_decode,
                                                        [-1, 1])
                    teacher_seq_len = know_data["seq_lens"]
                    assert len(word_ids) == len(label_ids)
                    real_len = len(word_ids) if len(
                        word_ids) < max_seq_len else max_seq_len
                    # the teacher sequence carries two extra boundary tokens;
                    # keep only samples whose lengths match after dropping them
                    if real_len == teacher_seq_len[0] - 2:
                        teacher_crf_decode_range = teacher_crf_decode[0][
                            1:teacher_seq_len[0] - 1]
                        teacher_crf_decode_range = np.reshape(
                            teacher_crf_decode_range, [-1, 1])
                        buf.append([
                            word_ids[0:max_seq_len], label_ids[0:max_seq_len],
                            teacher_crf_decode_range
                        ])
                        if len(buf) > self._train_shuffle_buf_size:
                            buf_ids = list(range(len(buf)))
                            random.shuffle(buf_ids)
                            for idx in buf_ids:
                                yield buf[idx]
                            buf = []
                    else:
                        invalid_samples += 1
                if len(buf) > 0:
                    buf_ids = list(range(len(buf)))
                    random.shuffle(buf_ids)
                    for idx in buf_ids:
                        yield buf[idx]
                print("invalid samples in one epoch: {}".format(
                    invalid_samples))
            fread.close()

        return wrapper


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "--word_dict_path",
        type=str,
        default="./conf/word.dic",
        help="word dict")
    parser.add_argument(
        "--label_dict_path",
        type=str,
        default="./conf/tag.dic",
        help="label dict")
    parser.add_argument(
        "--word_rep_dict_path",
        type=str,
        default="./conf/q2b.dic",
        help="word replace dict")
    # Dataset.__init__ also reads these two attributes; the defaults here are
    # placeholders for standalone testing.
    parser.add_argument(
        "--in_address",
        type=str,
        default="127.0.0.1:5000",
        help="teacher address for the pantheon Student")
    parser.add_argument(
        "--traindata_shuffle_buffer",
        type=int,
        default=200,
        help="shuffle buffer size for training data")
    args = parser.parse_args()
    dataset = Dataset(args)
    # data_generator = dataset.file_reader("data/train.tsv")
    # for word_idx, target_idx, teacher_decode in data_generator():
    #     print(word_idx, target_idx)
    #     print(len(word_idx), len(target_idx))
    #     break
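# A minimal sketch (not part of the original demo) of the dict format that
# load_kv_dict expects: one "<id>\t<value>" pair per line, as in
# ./conf/word.dic. With reverse=True the second field becomes the key, and
# value_func=int converts the id string. The file name below is hypothetical.
#
#     with io.open("demo_word.dic", "w", encoding="utf8") as f:
#         f.write(u"0\tOOV\n")
#         f.write(u"1\tthe\n")
#     word2id = load_kv_dict("demo_word.dic", reverse=True, value_func=int)
#     # -> {u"OOV": 0, u"the": 1}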