提交 5c5a64a3 编写于 作者: S Superjom

utils.py refactor

上级 4f13dbfa
...@@ -6,56 +6,100 @@ logger = logging.getLogger("logger") ...@@ -6,56 +6,100 @@ logger = logging.getLogger("logger")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class TaskType: def mode_attr_name(mode):
TRAIN_MODE = 0 return mode.upper() + '_MODE'
TEST_MODE = 1
INFER_MODE = 2
def __init__(self, mode):
self.mode = mode
def is_train(self): def create_attrs(cls):
return self.mode == self.TRAIN_MODE for id, mode in enumerate(cls.modes):
setattr(cls, mode_attr_name(mode), id)
def make_check_method(cls):
'''
create methods for classes.
'''
def method(mode):
def _method(self):
return self.mode == getattr(cls, mode_attr_name(mode))
return _method
for id, mode in enumerate(cls.modes):
setattr(cls, 'is_' + mode, method(mode))
def make_create_method(cls):
def method(mode):
@staticmethod
def _method():
key = getattr(cls, mode_attr_name(mode))
return cls(key)
return _method
for id, mode in enumerate(cls.modes):
setattr(cls, 'create_' + mode, method(mode))
def is_test(self):
return self.mode == self.TEST_MODE
def is_infer(self): def make_str_method(cls):
return self.mode == self.INFER_MODE def _str_(self):
for mode in cls.modes:
if self.mode == getattr(cls, mode_attr_name(mode)):
return mode
@staticmethod def _hash_(self):
def create_train(): return self.mode
return TaskType(TaskType.TRAIN_MODE)
@staticmethod setattr(cls, '__str__', _str_)
def create_test(): setattr(cls, '__repr__', _str_)
return TaskType(TaskType.TEST_MODE) setattr(cls, '__hash__', _hash_)
@staticmethod
def create_infer(): def _init_(self, mode, cls):
return TaskType(TaskType.INFER_MODE) if isinstance(mode, int):
self.mode = mode
elif isinstance(mode, cls):
self.mode = mode.mode
else:
raise
def build_mode_class(cls):
create_attrs(cls)
make_str_method(cls)
make_check_method(cls)
make_create_method(cls)
class TaskType(object):
# TRAIN_MODE = 0
# TEST_MODE = 1
# INFER_MODE = 2
modes = 'train test infer'.split()
def __init__(self, mode):
_init_(self, mode, TaskType)
class ModelType: class ModelType:
CLASSIFICATION = 0 modes = 'classification rank regression'.split()
RANK = 1
def __init__(self, mode): def __init__(self, mode):
self.mode = mode _init_(self, mode, ModelType)
def is_classification(self): class ModelArch:
return self.mode == self.CLASSIFICATION modes = 'fc cnn rnn'.split()
def is_rank(self): def __init__(self, mode):
return self.mode == self.RANK _init_(self, mode, ModelArch)
@staticmethod
def create_classification():
return ModelType(ModelType.CLASSIFICATION)
@staticmethod build_mode_class(TaskType)
def create_rank(): build_mode_class(ModelType)
return ModelType(ModelType.RANK) build_mode_class(ModelArch)
def sent2ids(sent, vocab): def sent2ids(sent, vocab):
...@@ -81,3 +125,10 @@ def load_dic(path): ...@@ -81,3 +125,10 @@ def load_dic(path):
w = line.strip() w = line.strip()
dic[w] = id dic[w] = id
return dic return dic
if __name__ == '__main__':
t = TaskType(1)
t = TaskType.create_train()
print t
print 'is', t.is_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册