From ffa2568a42c81032e33f7fbc047af43816ec7935 Mon Sep 17 00:00:00 2001
From: Peng Li
Date: Thu, 19 Oct 2017 20:07:27 +0800
Subject: [PATCH] fix style problem

---
 neural_seq_qa/reader.py           | 81 +++++++++++++++++--------------
 neural_seq_qa/test/test_reader.py |  9 ++--
 2 files changed, 49 insertions(+), 41 deletions(-)

diff --git a/neural_seq_qa/reader.py b/neural_seq_qa/reader.py
index 225cb074..e55e77b6 100644
--- a/neural_seq_qa/reader.py
+++ b/neural_seq_qa/reader.py
@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures
 import utils
 from utils import logger
 
-__all__ = ["Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM",
-           "Q_IDS_STR", "E_IDS_STR", "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR",
-           "Settings", "create_reader"]
+__all__ = [
+    "Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", "Q_IDS_STR", "E_IDS_STR",
+    "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Settings", "create_reader"
+]
 
 # slot names
 Q_IDS_STR = "q_ids"
@@ -25,7 +26,6 @@ LABELS = 2
 QE_COMM = 3
 EE_COMM = 4
 
-
 NO_ANSWER = "no_answer"
 
 
@@ -33,6 +33,7 @@ class Settings(object):
     """
     class for storing settings
     """
+
     def __init__(self,
                  vocab,
                  is_training,
@@ -75,15 +76,23 @@ class Settings(object):
         elif label_schema == "BIO2":
             B, I, O1, O2 = 0, 1, 2, 3
         else:
-            raise ValueError("label_schema should be BIO/BIO2") 
+            raise ValueError("label_schema should be BIO/BIO2")
 
         self.B, self.I, self.O1, self.O2 = B, I, O1, O2
-        self.label_map = {"B":B, "I":I, "O1":O1, "O2":O2,
-                          "b":B, "i":I, "o1":O1, "o2":O2}
+        self.label_map = {
+            "B": B,
+            "I": I,
+            "O1": O1,
+            "O2": O2,
+            "b": B,
+            "i": I,
+            "o1": O1,
+            "o2": O2
+        }
         self.label_num = len(set((B, I, O1, O2)))
         # id for OOV
         self.oov_id = 0
-        
+
         # set up random seed
         random.seed(seed)
 
@@ -94,7 +103,7 @@ class Settings(object):
         logger.info("keep_first_b: %s", keep_first_b)
         logger.info("data reader random seed: %d", seed)
 
-        
+
 class SampleStream(object):
     def __init__(self, filename, settings):
         self.filename = filename
@@ -102,7 +111,7 @@ class SampleStream(object):
 
     def __iter__(self):
         return self.load_and_filter_samples(self.filename)
-        
+
     def load_and_filter_samples(self, filename):
         def remove_extra_b(labels):
             if labels.count(self.settings.B) <= 1: return
@@ -111,7 +120,7 @@ class SampleStream(object):
             # find the first B
             while i < len(labels) and labels[i] == self.settings.O1:
                 i += 1
-            i += 1 # skip B
+            i += 1  # skip B
             # skip the following Is
             while i < len(labels) and labels[i] == self.settings.I:
                 i += 1
@@ -138,23 +147,22 @@ class SampleStream(object):
                 # matches in training
                 is_all_o1 = labels.count(self.settings.O1) == len(labels)
                 if self.settings.is_training and is_all_o1 and not is_negative:
-                    evidences[i] = None # dropped
+                    evidences[i] = None  # dropped
                     continue
 
                 if self.settings.keep_first_b:
                     remove_extra_b(labels)
                 evi[Evidence.GOLDEN_LABELS] = labels
 
-        def get_eecom_feats_list(cur_sample_is_negative,
-                                 eecom_feats_list,
+        def get_eecom_feats_list(cur_sample_is_negative, eecom_feats_list,
                                  evidences):
             if not self.settings.is_training:
-                return [item[EecommFeatures.EECOMM_FEATURES] \
-                    for item in eecom_feats_list]
+                return [item[EecommFeatures.EECOMM_FEATURES] \
+                        for item in eecom_feats_list]
 
             positive_eecom_feats_list = []
             negative_eecom_feats_list = []
-            
+
             for eecom_feats_, other_evi in izip(eecom_feats_list, evidences):
                 if not other_evi:
                     continue
@@ -174,7 +182,7 @@ class SampleStream(object):
                 eecom_feats_list = positive_eecom_feats_list
                 if negative_eecom_feats_list:
                     eecom_feats_list += [negative_eecom_feats_list]
-            
+
             return eecom_feats_list
 
         def process_tokens(data, tok_key):
@@ -189,16 +197,15 @@ class SampleStream(object):
             qe_comm = evi[Evidence.QECOMM_FEATURES]
             sample_type = evi[Evidence.TYPE]
 
-            ret = [None] * 5 
+            ret = [None] * 5
             ret[Q_IDS] = q_ids
             ret[E_IDS] = e_ids
             ret[LABELS] = labels
             ret[QE_COMM] = qe_comm
 
             eecom_feats_list = get_eecom_feats_list(
-                    sample_type != Evidence.POSITIVE,
-                    evi[Evidence.EECOMM_FEATURES_LIST],
-                    evidences)
+                sample_type != Evidence.POSITIVE,
+                evi[Evidence.EECOMM_FEATURES_LIST], evidences)
             if not eecom_feats_list:
                 return None
             else:
@@ -217,7 +224,7 @@ class SampleStream(object):
 
                 # convert question tokens to ids
                 q_ids = process_tokens(data, DataPoint.Q_TOKENS)
-                
+
                 # process evidences
                 evidences = data[DataPoint.EVIDENCES]
                 filter_and_preprocess_evidences(evidences)
@@ -226,13 +233,14 @@ class SampleStream(object):
                     sample = process_evi(q_ids, evi, evidences)
                     if sample:
                         yield q_idx, sample, evi[Evidence.TYPE]
+
 class DataReader(object):
     def __iter__(self):
         return self
 
     def _next(self):
         raise NotImplemented()
-    
+
     def next(self):
         data_point = self._next()
         return self.post_process_sample(data_point)
@@ -251,10 +259,7 @@ class DataReader(object):
 
 
 class TrainingDataReader(DataReader):
-    def __init__(self,
-                 sample_stream,
-                 negative_ratio,
-                 hit_ans_negative_ratio):
+    def __init__(self, sample_stream, negative_ratio, hit_ans_negative_ratio):
         super(TrainingDataReader, self).__init__()
         self.positive_data = []
         self.hit_ans_negative_data = []
@@ -308,8 +313,8 @@ class TrainingDataReader(DataReader):
         if len(self.positive_data) == 0:
             logger.fatal("zero positive sample")
             raise ValueError("zero positive sample")
-        
-        zero_hit = len(self.hit_ans_negative_data) == 0 
+
+        zero_hit = len(self.hit_ans_negative_data) == 0
         zero_other = len(self.other_negative_data) == 0
 
         if zero_hit and zero_other:
@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader):
             self.p_idx = 0
 
         self.p_idx += 1
-        return self.positive_data[self.p_idx-1]
+        return self.positive_data[self.p_idx - 1]
 
     def _next_negative_data(self, idx, negative_data):
         if idx >= len(negative_data):
@@ -352,16 +357,16 @@ class TrainingDataReader(DataReader):
             random.shuffle(bundle[0])
             bundle[1] = 0
         bundle[1] += 1
-        return idx+1, bundle[0][bundle[1]-1]
+        return idx + 1, bundle[0][bundle[1] - 1]
 
     def next_hit_ans_negative_data(self):
         self.hit_idx, data = self._next_negative_data(
-                self.hit_idx, self.hit_ans_negative_data)
+            self.hit_idx, self.hit_ans_negative_data)
         return data
 
     def next_other_negative_data(self):
         self.other_idx, data = self._next_negative_data(
-                self.other_idx, self.other_negative_data)
+            self.other_idx, self.other_negative_data)
         return data
 
     def _next(self):
@@ -387,16 +392,18 @@ class TestDataReader(DataReader):
 def create_reader(filename, settings, samples_per_pass=sys.maxint):
     if settings.is_training:
         training_reader = TrainingDataReader(
-                SampleStream(filename, settings),
-                settings.negative_sample_ratio,
-                settings.hit_ans_negative_sample_ratio)
+            SampleStream(filename, settings), settings.negative_sample_ratio,
+            settings.hit_ans_negative_sample_ratio)
 
         def wrapper():
             for i, data in izip(xrange(samples_per_pass), training_reader):
                 yield data
+
         return wrapper
     else:
+
        def wrapper():
            sample_stream = SampleStream(filename, settings)
            return TestDataReader(sample_stream)
+
        return wrapper
diff --git a/neural_seq_qa/test/test_reader.py b/neural_seq_qa/test/test_reader.py
index e9765785..2c3725b5 100644
--- a/neural_seq_qa/test/test_reader.py
+++ b/neural_seq_qa/test/test_reader.py
@@ -19,11 +19,12 @@ ch = logging.StreamHandler()
 ch.setFormatter(formatter)
 utils.logger.addHandler(ch)
 
+
 class Vocab(object):
     @property
     def data(self):
-        word_dict_path = os.path.join(
-            topdir, "data", "embedding", "wordvecs.vcb")
+        word_dict_path = os.path.join(topdir, "data", "embedding",
+                                      "wordvecs.vcb")
         return utils.load_dict(word_dict_path)
 
 
@@ -52,7 +53,7 @@ class NegativeSampleRatioTest(unittest.TestCase):
     def runTest(self):
         for ratio in [1., 0.25, 0.]:
             self.check_ratio(ratio)
-            
+
 
 class KeepFirstBTest(unittest.TestCase):
     def runTest(self):
@@ -103,7 +104,7 @@ class DictTest(unittest.TestCase):
 
         self.assertGreater(len(q_uniq_ids), 50)
         self.assertGreater(len(e_uniq_ids), 50)
-        
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab