提交 2ead8cea 编写于 作者: S Superjom

finish classification

上级 df91e730
from paddle import v2 as paddle from paddle import v2 as paddle
from paddle.v2.attr import ParamAttr from paddle.v2.attr import ParamAttr
from utils import TaskType, logger from utils import TaskType, logger, ModelType
class DSSM(object): class DSSM(object):
def __init__(self, def __init__(self,
dnn_dims=[], dnn_dims=[],
vocab_sizes=[], vocab_sizes=[],
task_type=TaskType.CLASSFICATION, model_type=ModelType.CLASSIFICATION,
share_semantic_generator=False, share_semantic_generator=False,
class_num=None, class_num=None,
share_embed=False): share_embed=False):
...@@ -16,7 +16,7 @@ class DSSM(object): ...@@ -16,7 +16,7 @@ class DSSM(object):
dimentions of each layer in semantic vector generator. dimentions of each layer in semantic vector generator.
@vocab_sizes: 2-d tuple @vocab_sizes: 2-d tuple
size of both left and right items. size of both left and right items.
@task_type: str @model_type: str
type of task, should be 'rank', 'regression' or 'classification' type of task, should be 'rank', 'regression' or 'classification'
@share_semantic_generator: bool @share_semantic_generator: bool
whether to share the semantic vector generator for both left and right. whether to share the semantic vector generator for both left and right.
...@@ -33,13 +33,13 @@ class DSSM(object): ...@@ -33,13 +33,13 @@ class DSSM(object):
self.vocab_sizes = vocab_sizes self.vocab_sizes = vocab_sizes
self.share_semantic_generator = share_semantic_generator self.share_semantic_generator = share_semantic_generator
self.share_embed = share_embed self.share_embed = share_embed
self.task_type = task_type self.model_type = model_type
self.class_num = class_num self.class_num = class_num
logger.info("vocabulary sizes: %s" % str(self.vocab_sizes)) logger.info("vocabulary sizes: %s" % str(self.vocab_sizes))
def __call__(self): def __call__(self):
if self.task_type == TaskType.CLASSFICATION: if self.model_type == ModelType.CLASSIFICATION:
return self._build_classification_model() return self._build_classification_model()
return self._build_rank_model() return self._build_rank_model()
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from utils import UNK, TaskType, load_dic, sent2ids, logger from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType
class Dataset(object): class Dataset(object):
...@@ -9,18 +9,18 @@ class Dataset(object): ...@@ -9,18 +9,18 @@ class Dataset(object):
test_path, test_path,
source_dic_path, source_dic_path,
target_dic_path, target_dic_path,
task_type=TaskType.RANK): model_type=ModelType.RANK):
self.train_path = train_path self.train_path = train_path
self.test_path = test_path self.test_path = test_path
self.source_dic_path = source_dic_path self.source_dic_path = source_dic_path
self.target_dic_path = target_dic_path self.target_dic_path = target_dic_path
self.task_type = task_type self.model_type = model_type
self.source_dic = load_dic(self.source_dic_path) self.source_dic = load_dic(self.source_dic_path)
self.target_dic = load_dic(self.target_dic_path) self.target_dic = load_dic(self.target_dic_path)
self.record_reader = self._read_classification_record \ self.record_reader = self._read_classification_record \
if self.task_type == TaskType.CLASSFICATION \ if self.model_type == ModelType.CLASSIFICATION \
else self._read_rank_record else self._read_rank_record
def train(self): def train(self):
...@@ -75,7 +75,7 @@ if __name__ == '__main__': ...@@ -75,7 +75,7 @@ if __name__ == '__main__':
test_path = './data/classification/test.txt' test_path = './data/classification/test.txt'
source_dic = './data/vocab.txt' source_dic = './data/vocab.txt'
dataset = Dataset(path, test_path, source_dic, source_dic, dataset = Dataset(path, test_path, source_dic, source_dic,
TaskType.CLASSFICATION) ModelType.CLASSIFICATION)
for rcd in dataset.train(): for rcd in dataset.train():
print rcd print rcd
......
...@@ -6,7 +6,7 @@ import gzip ...@@ -6,7 +6,7 @@ import gzip
import paddle.v2 as paddle import paddle.v2 as paddle
from network_conf import DSSM from network_conf import DSSM
import reader import reader
from utils import TaskType, load_dic, logger from utils import TaskType, load_dic, logger, ModelType
parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example") parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")
...@@ -42,10 +42,11 @@ parser.add_argument( ...@@ -42,10 +42,11 @@ parser.add_argument(
default=10, default=10,
help="number of passes to run(default:10)") help="number of passes to run(default:10)")
parser.add_argument( parser.add_argument(
'--task_type', '--model_type',
type=int, type=int,
default=TaskType.CLASSFICATION, default=ModelType.CLASSIFICATION,
help="task type, 0 for classification, 1 for pairwise rank") help="model type, %d for classification, %d for pairwise rank (default: classification)"
% (ModelType.CLASSIFICATION, ModelType.RANK))
parser.add_argument( parser.add_argument(
'--share_network_between_source_target', '--share_network_between_source_target',
type=bool, type=bool,
...@@ -66,6 +67,7 @@ parser.add_argument( ...@@ -66,6 +67,7 @@ parser.add_argument(
'--num_workers', type=int, default=1, help="num worker threads, default 1") '--num_workers', type=int, default=1, help="num worker threads, default 1")
args = parser.parse_args() args = parser.parse_args()
args.model_type = ModelType(args.model_type)
layer_dims = [int(i) for i in args.dnn_dims.split(',')] layer_dims = [int(i) for i in args.dnn_dims.split(',')]
target_dic_path = args.source_dic_path if not args.target_dic_path else args.target_dic_path target_dic_path = args.source_dic_path if not args.target_dic_path else args.target_dic_path
...@@ -75,7 +77,7 @@ def train(train_data_path=None, ...@@ -75,7 +77,7 @@ def train(train_data_path=None,
test_data_path=None, test_data_path=None,
source_dic_path=None, source_dic_path=None,
target_dic_path=None, target_dic_path=None,
task_type=TaskType.CLASSFICATION, model_type=ModelType.CLASSIFICATION,
batch_size=10, batch_size=10,
num_passes=10, num_passes=10,
share_semantic_generator=False, share_semantic_generator=False,
...@@ -88,7 +90,7 @@ def train(train_data_path=None, ...@@ -88,7 +90,7 @@ def train(train_data_path=None,
default_train_path = './data/rank/train.txt' default_train_path = './data/rank/train.txt'
default_test_path = './data/rank/test.txt' default_test_path = './data/rank/test.txt'
default_dic_path = './data/vocab.txt' default_dic_path = './data/vocab.txt'
if task_type == TaskType.CLASSFICATION: if model_type == ModelType.CLASSIFICATION:
default_train_path = './data/classification/train.txt' default_train_path = './data/classification/train.txt'
default_test_path = './data/classification/test.txt' default_test_path = './data/classification/test.txt'
...@@ -105,7 +107,7 @@ def train(train_data_path=None, ...@@ -105,7 +107,7 @@ def train(train_data_path=None,
test_path=test_data_path, test_path=test_data_path,
source_dic_path=source_dic_path, source_dic_path=source_dic_path,
target_dic_path=target_dic_path, target_dic_path=target_dic_path,
task_type=task_type, ) model_type=args.model_type, )
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle(dataset.train, buf_size=1000), paddle.reader.shuffle(dataset.train, buf_size=1000),
...@@ -122,7 +124,7 @@ def train(train_data_path=None, ...@@ -122,7 +124,7 @@ def train(train_data_path=None,
vocab_sizes=[ vocab_sizes=[
len(load_dic(path)) for path in [source_dic_path, target_dic_path] len(load_dic(path)) for path in [source_dic_path, target_dic_path]
], ],
task_type=task_type, model_type=model_type,
share_semantic_generator=share_semantic_generator, share_semantic_generator=share_semantic_generator,
class_num=class_num, class_num=class_num,
share_embed=share_embed)() share_embed=share_embed)()
...@@ -142,7 +144,7 @@ def train(train_data_path=None, ...@@ -142,7 +144,7 @@ def train(train_data_path=None,
update_equation=adam_optimizer) update_equation=adam_optimizer)
feeding = {} feeding = {}
if task_type == TaskType.CLASSFICATION: if model_type == ModelType.CLASSIFICATION:
feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2} feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
else: else:
feeding = { feeding = {
...@@ -163,7 +165,7 @@ def train(train_data_path=None, ...@@ -163,7 +165,7 @@ def train(train_data_path=None,
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if test_reader is not None: if test_reader is not None:
if task_type == TaskType.CLASSFICATION: if model_type == ModelType.CLASSIFICATION:
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
logger.info("Test at Pass %d, %s \n" % (event.pass_id, logger.info("Test at Pass %d, %s \n" % (event.pass_id,
result.metrics)) result.metrics))
...@@ -183,4 +185,4 @@ def train(train_data_path=None, ...@@ -183,4 +185,4 @@ def train(train_data_path=None,
if __name__ == '__main__': if __name__ == '__main__':
# train(class_num=2) # train(class_num=2)
train(task_type=TaskType.RANK) train(model_type=ModelType.RANK)
...@@ -7,13 +7,55 @@ logger.setLevel(logging.INFO) ...@@ -7,13 +7,55 @@ logger.setLevel(logging.INFO)
class TaskType: class TaskType:
''' TRAIN_MODE = 0
type of DSSM's task. TEST_MODE = 1
''' INFER_MODE = 2
# pairwise rank.
RANK = 0 def __init__(self, mode):
# classification. self.mode = mode
CLASSFICATION = 1
def is_train(self):
return self.mode == self.TRAIN_MODE
def is_test(self):
return self.mode == self.TEST_MODE
def is_infer(self):
return self.mode == self.INFER_MODE
@staticmethod
def create_train():
return TaskType(TaskType.TRAIN_MODE)
@staticmethod
def create_test():
return TaskType(TaskType.TEST_MODE)
@staticmethod
def create_infer():
return TaskType(TaskType.INFER_MODE)
class ModelType:
CLASSIFICATION = 0
RANK = 1
def __init__(self, mode):
self.mode = mode
def is_classification(self):
return self.mode == self.CLASSIFICATION
def is_rank(self):
return self.mode == self.RANK
@staticmethod
def create_classification():
return ModelType(ModelType.CLASSIFICATION)
@staticmethod
def create_rank():
return ModelType(ModelType.RANK)
def sent2ids(sent, vocab): def sent2ids(sent, vocab):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册