提交 ffa2568a 编写于 作者: P Peng Li

fix style problem

上级 db575c5b
...@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures ...@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures
import utils import utils
from utils import logger from utils import logger
__all__ = ["Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", __all__ = [
"Q_IDS_STR", "E_IDS_STR", "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", "Q_IDS_STR", "E_IDS_STR",
"Settings", "create_reader"] "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Settings", "create_reader"
]
# slot names # slot names
Q_IDS_STR = "q_ids" Q_IDS_STR = "q_ids"
...@@ -25,7 +26,6 @@ LABELS = 2 ...@@ -25,7 +26,6 @@ LABELS = 2
QE_COMM = 3 QE_COMM = 3
EE_COMM = 4 EE_COMM = 4
NO_ANSWER = "no_answer" NO_ANSWER = "no_answer"
...@@ -33,6 +33,7 @@ class Settings(object): ...@@ -33,6 +33,7 @@ class Settings(object):
""" """
class for storing settings class for storing settings
""" """
def __init__(self, def __init__(self,
vocab, vocab,
is_training, is_training,
...@@ -77,8 +78,16 @@ class Settings(object): ...@@ -77,8 +78,16 @@ class Settings(object):
else: else:
raise ValueError("label_schema should be BIO/BIO2") raise ValueError("label_schema should be BIO/BIO2")
self.B, self.I, self.O1, self.O2 = B, I, O1, O2 self.B, self.I, self.O1, self.O2 = B, I, O1, O2
self.label_map = {"B":B, "I":I, "O1":O1, "O2":O2, self.label_map = {
"b":B, "i":I, "o1":O1, "o2":O2} "B": B,
"I": I,
"O1": O1,
"O2": O2,
"b": B,
"i": I,
"o1": O1,
"o2": O2
}
self.label_num = len(set((B, I, O1, O2))) self.label_num = len(set((B, I, O1, O2)))
# id for OOV # id for OOV
...@@ -145,8 +154,7 @@ class SampleStream(object): ...@@ -145,8 +154,7 @@ class SampleStream(object):
remove_extra_b(labels) remove_extra_b(labels)
evi[Evidence.GOLDEN_LABELS] = labels evi[Evidence.GOLDEN_LABELS] = labels
def get_eecom_feats_list(cur_sample_is_negative, def get_eecom_feats_list(cur_sample_is_negative, eecom_feats_list,
eecom_feats_list,
evidences): evidences):
if not self.settings.is_training: if not self.settings.is_training:
return [item[EecommFeatures.EECOMM_FEATURES] \ return [item[EecommFeatures.EECOMM_FEATURES] \
...@@ -197,8 +205,7 @@ class SampleStream(object): ...@@ -197,8 +205,7 @@ class SampleStream(object):
eecom_feats_list = get_eecom_feats_list( eecom_feats_list = get_eecom_feats_list(
sample_type != Evidence.POSITIVE, sample_type != Evidence.POSITIVE,
evi[Evidence.EECOMM_FEATURES_LIST], evi[Evidence.EECOMM_FEATURES_LIST], evidences)
evidences)
if not eecom_feats_list: if not eecom_feats_list:
return None return None
else: else:
...@@ -226,6 +233,7 @@ class SampleStream(object): ...@@ -226,6 +233,7 @@ class SampleStream(object):
sample = process_evi(q_ids, evi, evidences) sample = process_evi(q_ids, evi, evidences)
if sample: yield q_idx, sample, evi[Evidence.TYPE] if sample: yield q_idx, sample, evi[Evidence.TYPE]
class DataReader(object): class DataReader(object):
def __iter__(self): def __iter__(self):
return self return self
...@@ -251,10 +259,7 @@ class DataReader(object): ...@@ -251,10 +259,7 @@ class DataReader(object):
class TrainingDataReader(DataReader): class TrainingDataReader(DataReader):
def __init__(self, def __init__(self, sample_stream, negative_ratio, hit_ans_negative_ratio):
sample_stream,
negative_ratio,
hit_ans_negative_ratio):
super(TrainingDataReader, self).__init__() super(TrainingDataReader, self).__init__()
self.positive_data = [] self.positive_data = []
self.hit_ans_negative_data = [] self.hit_ans_negative_data = []
...@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader): ...@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader):
self.p_idx = 0 self.p_idx = 0
self.p_idx += 1 self.p_idx += 1
return self.positive_data[self.p_idx-1] return self.positive_data[self.p_idx - 1]
def _next_negative_data(self, idx, negative_data): def _next_negative_data(self, idx, negative_data):
if idx >= len(negative_data): if idx >= len(negative_data):
...@@ -352,7 +357,7 @@ class TrainingDataReader(DataReader): ...@@ -352,7 +357,7 @@ class TrainingDataReader(DataReader):
random.shuffle(bundle[0]) random.shuffle(bundle[0])
bundle[1] = 0 bundle[1] = 0
bundle[1] += 1 bundle[1] += 1
return idx+1, bundle[0][bundle[1]-1] return idx + 1, bundle[0][bundle[1] - 1]
def next_hit_ans_negative_data(self): def next_hit_ans_negative_data(self):
self.hit_idx, data = self._next_negative_data( self.hit_idx, data = self._next_negative_data(
...@@ -387,16 +392,18 @@ class TestDataReader(DataReader): ...@@ -387,16 +392,18 @@ class TestDataReader(DataReader):
def create_reader(filename, settings, samples_per_pass=sys.maxint): def create_reader(filename, settings, samples_per_pass=sys.maxint):
if settings.is_training: if settings.is_training:
training_reader = TrainingDataReader( training_reader = TrainingDataReader(
SampleStream(filename, settings), SampleStream(filename, settings), settings.negative_sample_ratio,
settings.negative_sample_ratio,
settings.hit_ans_negative_sample_ratio) settings.hit_ans_negative_sample_ratio)
def wrapper(): def wrapper():
for i, data in izip(xrange(samples_per_pass), training_reader): for i, data in izip(xrange(samples_per_pass), training_reader):
yield data yield data
return wrapper return wrapper
else: else:
def wrapper(): def wrapper():
sample_stream = SampleStream(filename, settings) sample_stream = SampleStream(filename, settings)
return TestDataReader(sample_stream) return TestDataReader(sample_stream)
return wrapper return wrapper
...@@ -19,11 +19,12 @@ ch = logging.StreamHandler() ...@@ -19,11 +19,12 @@ ch = logging.StreamHandler()
ch.setFormatter(formatter) ch.setFormatter(formatter)
utils.logger.addHandler(ch) utils.logger.addHandler(ch)
class Vocab(object): class Vocab(object):
@property @property
def data(self): def data(self):
word_dict_path = os.path.join( word_dict_path = os.path.join(topdir, "data", "embedding",
topdir, "data", "embedding", "wordvecs.vcb") "wordvecs.vcb")
return utils.load_dict(word_dict_path) return utils.load_dict(word_dict_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册