提交 5c5a64a3 编写于 作者: S Superjom

utils.py refactor

上级 4f13dbfa
......@@ -6,56 +6,100 @@ logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
class TaskType:
TRAIN_MODE = 0
TEST_MODE = 1
INFER_MODE = 2
def mode_attr_name(mode):
return mode.upper() + '_MODE'
def __init__(self, mode):
self.mode = mode
def is_train(self):
return self.mode == self.TRAIN_MODE
def create_attrs(cls):
for id, mode in enumerate(cls.modes):
setattr(cls, mode_attr_name(mode), id)
def make_check_method(cls):
'''
create methods for classes.
'''
def method(mode):
def _method(self):
return self.mode == getattr(cls, mode_attr_name(mode))
return _method
for id, mode in enumerate(cls.modes):
setattr(cls, 'is_' + mode, method(mode))
def make_create_method(cls):
def method(mode):
@staticmethod
def _method():
key = getattr(cls, mode_attr_name(mode))
return cls(key)
return _method
for id, mode in enumerate(cls.modes):
setattr(cls, 'create_' + mode, method(mode))
def is_test(self):
return self.mode == self.TEST_MODE
def is_infer(self):
return self.mode == self.INFER_MODE
def make_str_method(cls):
def _str_(self):
for mode in cls.modes:
if self.mode == getattr(cls, mode_attr_name(mode)):
return mode
@staticmethod
def create_train():
return TaskType(TaskType.TRAIN_MODE)
def _hash_(self):
return self.mode
@staticmethod
def create_test():
return TaskType(TaskType.TEST_MODE)
setattr(cls, '__str__', _str_)
setattr(cls, '__repr__', _str_)
setattr(cls, '__hash__', _hash_)
@staticmethod
def create_infer():
return TaskType(TaskType.INFER_MODE)
def _init_(self, mode, cls):
if isinstance(mode, int):
self.mode = mode
elif isinstance(mode, cls):
self.mode = mode.mode
else:
raise
def build_mode_class(cls):
create_attrs(cls)
make_str_method(cls)
make_check_method(cls)
make_create_method(cls)
class TaskType(object):
# TRAIN_MODE = 0
# TEST_MODE = 1
# INFER_MODE = 2
modes = 'train test infer'.split()
def __init__(self, mode):
_init_(self, mode, TaskType)
class ModelType:
CLASSIFICATION = 0
RANK = 1
modes = 'classification rank regression'.split()
def __init__(self, mode):
self.mode = mode
_init_(self, mode, ModelType)
def is_classification(self):
return self.mode == self.CLASSIFICATION
class ModelArch:
modes = 'fc cnn rnn'.split()
def is_rank(self):
return self.mode == self.RANK
def __init__(self, mode):
_init_(self, mode, ModelArch)
@staticmethod
def create_classification():
return ModelType(ModelType.CLASSIFICATION)
@staticmethod
def create_rank():
return ModelType(ModelType.RANK)
build_mode_class(TaskType)
build_mode_class(ModelType)
build_mode_class(ModelArch)
def sent2ids(sent, vocab):
......@@ -81,3 +125,10 @@ def load_dic(path):
w = line.strip()
dic[w] = id
return dic
if __name__ == '__main__':
t = TaskType(1)
t = TaskType.create_train()
print t
print 'is', t.is_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册