提交 ffa2568a 编写于 作者: P Peng Li

fix style problem

上级 db575c5b
......@@ -8,9 +8,10 @@ from datapoint import DataPoint, Evidence, EecommFeatures
import utils
from utils import logger
__all__ = ["Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM",
"Q_IDS_STR", "E_IDS_STR", "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR",
"Settings", "create_reader"]
__all__ = [
"Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", "Q_IDS_STR", "E_IDS_STR",
"LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Settings", "create_reader"
]
# slot names
Q_IDS_STR = "q_ids"
......@@ -25,7 +26,6 @@ LABELS = 2
QE_COMM = 3
EE_COMM = 4
NO_ANSWER = "no_answer"
......@@ -33,6 +33,7 @@ class Settings(object):
"""
class for storing settings
"""
def __init__(self,
vocab,
is_training,
......@@ -77,8 +78,16 @@ class Settings(object):
else:
raise ValueError("label_schema should be BIO/BIO2")
self.B, self.I, self.O1, self.O2 = B, I, O1, O2
self.label_map = {"B":B, "I":I, "O1":O1, "O2":O2,
"b":B, "i":I, "o1":O1, "o2":O2}
self.label_map = {
"B": B,
"I": I,
"O1": O1,
"O2": O2,
"b": B,
"i": I,
"o1": O1,
"o2": O2
}
self.label_num = len(set((B, I, O1, O2)))
# id for OOV
......@@ -145,8 +154,7 @@ class SampleStream(object):
remove_extra_b(labels)
evi[Evidence.GOLDEN_LABELS] = labels
def get_eecom_feats_list(cur_sample_is_negative,
eecom_feats_list,
def get_eecom_feats_list(cur_sample_is_negative, eecom_feats_list,
evidences):
if not self.settings.is_training:
return [item[EecommFeatures.EECOMM_FEATURES] \
......@@ -197,8 +205,7 @@ class SampleStream(object):
eecom_feats_list = get_eecom_feats_list(
sample_type != Evidence.POSITIVE,
evi[Evidence.EECOMM_FEATURES_LIST],
evidences)
evi[Evidence.EECOMM_FEATURES_LIST], evidences)
if not eecom_feats_list:
return None
else:
......@@ -226,6 +233,7 @@ class SampleStream(object):
sample = process_evi(q_ids, evi, evidences)
if sample: yield q_idx, sample, evi[Evidence.TYPE]
class DataReader(object):
def __iter__(self):
return self
......@@ -251,10 +259,7 @@ class DataReader(object):
class TrainingDataReader(DataReader):
def __init__(self,
sample_stream,
negative_ratio,
hit_ans_negative_ratio):
def __init__(self, sample_stream, negative_ratio, hit_ans_negative_ratio):
super(TrainingDataReader, self).__init__()
self.positive_data = []
self.hit_ans_negative_data = []
......@@ -335,7 +340,7 @@ class TrainingDataReader(DataReader):
self.p_idx = 0
self.p_idx += 1
return self.positive_data[self.p_idx-1]
return self.positive_data[self.p_idx - 1]
def _next_negative_data(self, idx, negative_data):
if idx >= len(negative_data):
......@@ -352,7 +357,7 @@ class TrainingDataReader(DataReader):
random.shuffle(bundle[0])
bundle[1] = 0
bundle[1] += 1
return idx+1, bundle[0][bundle[1]-1]
return idx + 1, bundle[0][bundle[1] - 1]
def next_hit_ans_negative_data(self):
self.hit_idx, data = self._next_negative_data(
......@@ -387,16 +392,18 @@ class TestDataReader(DataReader):
def create_reader(filename, settings, samples_per_pass=sys.maxint):
if settings.is_training:
training_reader = TrainingDataReader(
SampleStream(filename, settings),
settings.negative_sample_ratio,
SampleStream(filename, settings), settings.negative_sample_ratio,
settings.hit_ans_negative_sample_ratio)
def wrapper():
for i, data in izip(xrange(samples_per_pass), training_reader):
yield data
return wrapper
else:
def wrapper():
sample_stream = SampleStream(filename, settings)
return TestDataReader(sample_stream)
return wrapper
......@@ -19,11 +19,12 @@ ch = logging.StreamHandler()
ch.setFormatter(formatter)
utils.logger.addHandler(ch)
class Vocab(object):
@property
def data(self):
word_dict_path = os.path.join(
topdir, "data", "embedding", "wordvecs.vcb")
word_dict_path = os.path.join(topdir, "data", "embedding",
"wordvecs.vcb")
return utils.load_dict(word_dict_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册