提交 ffa2568a 编写于 作者: P Peng Li

fix style problem

上级 db575c5b
...@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures ...@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures
import utils import utils
from utils import logger from utils import logger
__all__ = ["Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", __all__ = [
"Q_IDS_STR", "E_IDS_STR", "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", "Q_IDS_STR", "E_IDS_STR",
"Settings", "create_reader"] "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Settings", "create_reader"
]
# slot names # slot names
Q_IDS_STR = "q_ids" Q_IDS_STR = "q_ids"
...@@ -25,7 +26,6 @@ LABELS = 2 ...@@ -25,7 +26,6 @@ LABELS = 2
QE_COMM = 3 QE_COMM = 3
EE_COMM = 4 EE_COMM = 4
NO_ANSWER = "no_answer" NO_ANSWER = "no_answer"
...@@ -33,6 +33,7 @@ class Settings(object): ...@@ -33,6 +33,7 @@ class Settings(object):
""" """
class for storing settings class for storing settings
""" """
def __init__(self, def __init__(self,
vocab, vocab,
is_training, is_training,
...@@ -75,15 +76,23 @@ class Settings(object): ...@@ -75,15 +76,23 @@ class Settings(object):
elif label_schema == "BIO2": elif label_schema == "BIO2":
B, I, O1, O2 = 0, 1, 2, 3 B, I, O1, O2 = 0, 1, 2, 3
else: else:
raise ValueError("label_schema should be BIO/BIO2") raise ValueError("label_schema should be BIO/BIO2")
self.B, self.I, self.O1, self.O2 = B, I, O1, O2 self.B, self.I, self.O1, self.O2 = B, I, O1, O2
self.label_map = {"B":B, "I":I, "O1":O1, "O2":O2, self.label_map = {
"b":B, "i":I, "o1":O1, "o2":O2} "B": B,
"I": I,
"O1": O1,
"O2": O2,
"b": B,
"i": I,
"o1": O1,
"o2": O2
}
self.label_num = len(set((B, I, O1, O2))) self.label_num = len(set((B, I, O1, O2)))
# id for OOV # id for OOV
self.oov_id = 0 self.oov_id = 0
# set up random seed # set up random seed
random.seed(seed) random.seed(seed)
...@@ -94,7 +103,7 @@ class Settings(object): ...@@ -94,7 +103,7 @@ class Settings(object):
logger.info("keep_first_b: %s", keep_first_b) logger.info("keep_first_b: %s", keep_first_b)
logger.info("data reader random seed: %d", seed) logger.info("data reader random seed: %d", seed)
class SampleStream(object): class SampleStream(object):
def __init__(self, filename, settings): def __init__(self, filename, settings):
self.filename = filename self.filename = filename
...@@ -102,7 +111,7 @@ class SampleStream(object): ...@@ -102,7 +111,7 @@ class SampleStream(object):
def __iter__(self): def __iter__(self):
return self.load_and_filter_samples(self.filename) return self.load_and_filter_samples(self.filename)
def load_and_filter_samples(self, filename): def load_and_filter_samples(self, filename):
def remove_extra_b(labels): def remove_extra_b(labels):
if labels.count(self.settings.B) <= 1: return if labels.count(self.settings.B) <= 1: return
...@@ -111,7 +120,7 @@ class SampleStream(object): ...@@ -111,7 +120,7 @@ class SampleStream(object):
# find the first B # find the first B
while i < len(labels) and labels[i] == self.settings.O1: while i < len(labels) and labels[i] == self.settings.O1:
i += 1 i += 1
i += 1 # skip B i += 1 # skip B
# skip the following Is # skip the following Is
while i < len(labels) and labels[i] == self.settings.I: while i < len(labels) and labels[i] == self.settings.I:
i += 1 i += 1
...@@ -138,23 +147,22 @@ class SampleStream(object): ...@@ -138,23 +147,22 @@ class SampleStream(object):
# matches in training # matches in training
is_all_o1 = labels.count(self.settings.O1) == len(labels) is_all_o1 = labels.count(self.settings.O1) == len(labels)
if self.settings.is_training and is_all_o1 and not is_negative: if self.settings.is_training and is_all_o1 and not is_negative:
evidences[i] = None # dropped evidences[i] = None # dropped
continue continue
if self.settings.keep_first_b: if self.settings.keep_first_b:
remove_extra_b(labels) remove_extra_b(labels)
evi[Evidence.GOLDEN_LABELS] = labels evi[Evidence.GOLDEN_LABELS] = labels
def get_eecom_feats_list(cur_sample_is_negative, def get_eecom_feats_list(cur_sample_is_negative, eecom_feats_list,
eecom_feats_list,
evidences): evidences):
if not self.settings.is_training: if not self.settings.is_training:
return [item[EecommFeatures.EECOMM_FEATURES] \ return [item[EecommFeatures.EECOMM_FEATURES] \
for item in eecom_feats_list] for item in eecom_feats_list]
positive_eecom_feats_list = [] positive_eecom_feats_list = []
negative_eecom_feats_list = [] negative_eecom_feats_list = []
for eecom_feats_, other_evi in izip(eecom_feats_list, evidences): for eecom_feats_, other_evi in izip(eecom_feats_list, evidences):
if not other_evi: continue if not other_evi: continue
...@@ -174,7 +182,7 @@ class SampleStream(object): ...@@ -174,7 +182,7 @@ class SampleStream(object):
eecom_feats_list = positive_eecom_feats_list eecom_feats_list = positive_eecom_feats_list
if negative_eecom_feats_list: if negative_eecom_feats_list:
eecom_feats_list += [negative_eecom_feats_list] eecom_feats_list += [negative_eecom_feats_list]
return eecom_feats_list return eecom_feats_list
def process_tokens(data, tok_key): def process_tokens(data, tok_key):
...@@ -189,16 +197,15 @@ class SampleStream(object): ...@@ -189,16 +197,15 @@ class SampleStream(object):
qe_comm = evi[Evidence.QECOMM_FEATURES] qe_comm = evi[Evidence.QECOMM_FEATURES]
sample_type = evi[Evidence.TYPE] sample_type = evi[Evidence.TYPE]
ret = [None] * 5 ret = [None] * 5
ret[Q_IDS] = q_ids ret[Q_IDS] = q_ids
ret[E_IDS] = e_ids ret[E_IDS] = e_ids
ret[LABELS] = labels ret[LABELS] = labels
ret[QE_COMM] = qe_comm ret[QE_COMM] = qe_comm
eecom_feats_list = get_eecom_feats_list( eecom_feats_list = get_eecom_feats_list(
sample_type != Evidence.POSITIVE, sample_type != Evidence.POSITIVE,
evi[Evidence.EECOMM_FEATURES_LIST], evi[Evidence.EECOMM_FEATURES_LIST], evidences)
evidences)
if not eecom_feats_list: if not eecom_feats_list:
return None return None
else: else:
...@@ -217,7 +224,7 @@ class SampleStream(object): ...@@ -217,7 +224,7 @@ class SampleStream(object):
# convert question tokens to ids # convert question tokens to ids
q_ids = process_tokens(data, DataPoint.Q_TOKENS) q_ids = process_tokens(data, DataPoint.Q_TOKENS)
# process evidences # process evidences
evidences = data[DataPoint.EVIDENCES] evidences = data[DataPoint.EVIDENCES]
filter_and_preprocess_evidences(evidences) filter_and_preprocess_evidences(evidences)
...@@ -226,13 +233,14 @@ class SampleStream(object): ...@@ -226,13 +233,14 @@ class SampleStream(object):
sample = process_evi(q_ids, evi, evidences) sample = process_evi(q_ids, evi, evidences)
if sample: yield q_idx, sample, evi[Evidence.TYPE] if sample: yield q_idx, sample, evi[Evidence.TYPE]
class DataReader(object): class DataReader(object):
def __iter__(self): def __iter__(self):
return self return self
def _next(self): def _next(self):
raise NotImplemented() raise NotImplemented()
def next(self): def next(self):
data_point = self._next() data_point = self._next()
return self.post_process_sample(data_point) return self.post_process_sample(data_point)
...@@ -251,10 +259,7 @@ class DataReader(object): ...@@ -251,10 +259,7 @@ class DataReader(object):
class TrainingDataReader(DataReader): class TrainingDataReader(DataReader):
def __init__(self, def __init__(self, sample_stream, negative_ratio, hit_ans_negative_ratio):
sample_stream,
negative_ratio,
hit_ans_negative_ratio):
super(TrainingDataReader, self).__init__() super(TrainingDataReader, self).__init__()
self.positive_data = [] self.positive_data = []
self.hit_ans_negative_data = [] self.hit_ans_negative_data = []
...@@ -308,8 +313,8 @@ class TrainingDataReader(DataReader): ...@@ -308,8 +313,8 @@ class TrainingDataReader(DataReader):
if len(self.positive_data) == 0: if len(self.positive_data) == 0:
logger.fatal("zero positive sample") logger.fatal("zero positive sample")
raise ValueError("zero positive sample") raise ValueError("zero positive sample")
zero_hit = len(self.hit_ans_negative_data) == 0 zero_hit = len(self.hit_ans_negative_data) == 0
zero_other = len(self.other_negative_data) == 0 zero_other = len(self.other_negative_data) == 0
if zero_hit and zero_other: if zero_hit and zero_other:
...@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader): ...@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader):
self.p_idx = 0 self.p_idx = 0
self.p_idx += 1 self.p_idx += 1
return self.positive_data[self.p_idx-1] return self.positive_data[self.p_idx - 1]
def _next_negative_data(self, idx, negative_data): def _next_negative_data(self, idx, negative_data):
if idx >= len(negative_data): if idx >= len(negative_data):
...@@ -352,16 +357,16 @@ class TrainingDataReader(DataReader): ...@@ -352,16 +357,16 @@ class TrainingDataReader(DataReader):
random.shuffle(bundle[0]) random.shuffle(bundle[0])
bundle[1] = 0 bundle[1] = 0
bundle[1] += 1 bundle[1] += 1
return idx+1, bundle[0][bundle[1]-1] return idx + 1, bundle[0][bundle[1] - 1]
def next_hit_ans_negative_data(self): def next_hit_ans_negative_data(self):
self.hit_idx, data = self._next_negative_data( self.hit_idx, data = self._next_negative_data(
self.hit_idx, self.hit_ans_negative_data) self.hit_idx, self.hit_ans_negative_data)
return data return data
def next_other_negative_data(self): def next_other_negative_data(self):
self.other_idx, data = self._next_negative_data( self.other_idx, data = self._next_negative_data(
self.other_idx, self.other_negative_data) self.other_idx, self.other_negative_data)
return data return data
def _next(self): def _next(self):
...@@ -387,16 +392,18 @@ class TestDataReader(DataReader): ...@@ -387,16 +392,18 @@ class TestDataReader(DataReader):
def create_reader(filename, settings, samples_per_pass=sys.maxint): def create_reader(filename, settings, samples_per_pass=sys.maxint):
if settings.is_training: if settings.is_training:
training_reader = TrainingDataReader( training_reader = TrainingDataReader(
SampleStream(filename, settings), SampleStream(filename, settings), settings.negative_sample_ratio,
settings.negative_sample_ratio, settings.hit_ans_negative_sample_ratio)
settings.hit_ans_negative_sample_ratio)
def wrapper(): def wrapper():
for i, data in izip(xrange(samples_per_pass), training_reader): for i, data in izip(xrange(samples_per_pass), training_reader):
yield data yield data
return wrapper return wrapper
else: else:
def wrapper(): def wrapper():
sample_stream = SampleStream(filename, settings) sample_stream = SampleStream(filename, settings)
return TestDataReader(sample_stream) return TestDataReader(sample_stream)
return wrapper return wrapper
...@@ -19,11 +19,12 @@ ch = logging.StreamHandler() ...@@ -19,11 +19,12 @@ ch = logging.StreamHandler()
ch.setFormatter(formatter) ch.setFormatter(formatter)
utils.logger.addHandler(ch) utils.logger.addHandler(ch)
class Vocab(object): class Vocab(object):
@property @property
def data(self): def data(self):
word_dict_path = os.path.join( word_dict_path = os.path.join(topdir, "data", "embedding",
topdir, "data", "embedding", "wordvecs.vcb") "wordvecs.vcb")
return utils.load_dict(word_dict_path) return utils.load_dict(word_dict_path)
...@@ -52,7 +53,7 @@ class NegativeSampleRatioTest(unittest.TestCase): ...@@ -52,7 +53,7 @@ class NegativeSampleRatioTest(unittest.TestCase):
def runTest(self): def runTest(self):
for ratio in [1., 0.25, 0.]: for ratio in [1., 0.25, 0.]:
self.check_ratio(ratio) self.check_ratio(ratio)
class KeepFirstBTest(unittest.TestCase): class KeepFirstBTest(unittest.TestCase):
def runTest(self): def runTest(self):
...@@ -103,7 +104,7 @@ class DictTest(unittest.TestCase): ...@@ -103,7 +104,7 @@ class DictTest(unittest.TestCase):
self.assertGreater(len(q_uniq_ids), 50) self.assertGreater(len(q_uniq_ids), 50)
self.assertGreater(len(e_uniq_ids), 50) self.assertGreater(len(e_uniq_ids), 50)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册