提交 2ead8cea 编写于 作者: S Superjom

finish classification

上级 df91e730
from paddle import v2 as paddle
from paddle.v2.attr import ParamAttr
from utils import TaskType, logger
from utils import TaskType, logger, ModelType
class DSSM(object):
def __init__(self,
dnn_dims=[],
vocab_sizes=[],
task_type=TaskType.CLASSFICATION,
model_type=ModelType.CLASSIFICATION,
share_semantic_generator=False,
class_num=None,
share_embed=False):
......@@ -16,7 +16,7 @@ class DSSM(object):
dimensions of each layer in the semantic vector generator.
@vocab_sizes: 2-d tuple
size of both left and right items.
@task_type: str
@model_type: str
type of task, should be 'rank', 'regression' or 'classification'
@share_semantic_generator: bool
whether to share the semantic vector generator for both left and right.
......@@ -33,13 +33,13 @@ class DSSM(object):
self.vocab_sizes = vocab_sizes
self.share_semantic_generator = share_semantic_generator
self.share_embed = share_embed
self.task_type = task_type
self.model_type = model_type
self.class_num = class_num
logger.info("vocabulary sizes: %s" % str(self.vocab_sizes))
def __call__(self):
if self.task_type == TaskType.CLASSFICATION:
if self.model_type == ModelType.CLASSIFICATION:
return self._build_classification_model()
return self._build_rank_model()
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from utils import UNK, TaskType, load_dic, sent2ids, logger
from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType
class Dataset(object):
......@@ -9,18 +9,18 @@ class Dataset(object):
test_path,
source_dic_path,
target_dic_path,
task_type=TaskType.RANK):
model_type=ModelType.RANK):
self.train_path = train_path
self.test_path = test_path
self.source_dic_path = source_dic_path
self.target_dic_path = target_dic_path
self.task_type = task_type
self.model_type = model_type
self.source_dic = load_dic(self.source_dic_path)
self.target_dic = load_dic(self.target_dic_path)
self.record_reader = self._read_classification_record \
if self.task_type == TaskType.CLASSFICATION \
if self.model_type == ModelType.CLASSIFICATION \
else self._read_rank_record
def train(self):
......@@ -75,7 +75,7 @@ if __name__ == '__main__':
test_path = './data/classification/test.txt'
source_dic = './data/vocab.txt'
dataset = Dataset(path, test_path, source_dic, source_dic,
TaskType.CLASSFICATION)
ModelType.CLASSIFICATION)
for rcd in dataset.train():
print rcd
......
......@@ -6,7 +6,7 @@ import gzip
import paddle.v2 as paddle
from network_conf import DSSM
import reader
from utils import TaskType, load_dic, logger
from utils import TaskType, load_dic, logger, ModelType
parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")
......@@ -42,10 +42,11 @@ parser.add_argument(
default=10,
help="number of passes to run(default:10)")
parser.add_argument(
'--task_type',
'--model_type',
type=int,
default=TaskType.CLASSFICATION,
help="task type, 0 for classification, 1 for pairwise rank")
default=ModelType.CLASSIFICATION,
help="model type, %d for classification, %d for pairwise rank (default: classification)"
% (ModelType.CLASSIFICATION, ModelType.RANK))
parser.add_argument(
'--share_network_between_source_target',
type=bool,
......@@ -66,6 +67,7 @@ parser.add_argument(
'--num_workers', type=int, default=1, help="num worker threads, default 1")
args = parser.parse_args()
args.model_type = ModelType(args.model_type)
layer_dims = [int(i) for i in args.dnn_dims.split(',')]
target_dic_path = args.source_dic_path if not args.target_dic_path else args.target_dic_path
......@@ -75,7 +77,7 @@ def train(train_data_path=None,
test_data_path=None,
source_dic_path=None,
target_dic_path=None,
task_type=TaskType.CLASSFICATION,
model_type=ModelType.CLASSIFICATION,
batch_size=10,
num_passes=10,
share_semantic_generator=False,
......@@ -88,7 +90,7 @@ def train(train_data_path=None,
default_train_path = './data/rank/train.txt'
default_test_path = './data/rank/test.txt'
default_dic_path = './data/vocab.txt'
if task_type == TaskType.CLASSFICATION:
if model_type == ModelType.CLASSIFICATION:
default_train_path = './data/classification/train.txt'
default_test_path = './data/classification/test.txt'
......@@ -105,7 +107,7 @@ def train(train_data_path=None,
test_path=test_data_path,
source_dic_path=source_dic_path,
target_dic_path=target_dic_path,
task_type=task_type, )
model_type=args.model_type, )
train_reader = paddle.batch(
paddle.reader.shuffle(dataset.train, buf_size=1000),
......@@ -122,7 +124,7 @@ def train(train_data_path=None,
vocab_sizes=[
len(load_dic(path)) for path in [source_dic_path, target_dic_path]
],
task_type=task_type,
model_type=model_type,
share_semantic_generator=share_semantic_generator,
class_num=class_num,
share_embed=share_embed)()
......@@ -142,7 +144,7 @@ def train(train_data_path=None,
update_equation=adam_optimizer)
feeding = {}
if task_type == TaskType.CLASSFICATION:
if model_type == ModelType.CLASSIFICATION:
feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
else:
feeding = {
......@@ -163,7 +165,7 @@ def train(train_data_path=None,
if isinstance(event, paddle.event.EndPass):
if test_reader is not None:
if task_type == TaskType.CLASSFICATION:
if model_type == ModelType.CLASSIFICATION:
result = trainer.test(reader=test_reader, feeding=feeding)
logger.info("Test at Pass %d, %s \n" % (event.pass_id,
result.metrics))
......@@ -183,4 +185,4 @@ def train(train_data_path=None,
if __name__ == '__main__':
# train(class_num=2)
train(task_type=TaskType.RANK)
train(model_type=ModelType.RANK)
......@@ -7,13 +7,55 @@ logger.setLevel(logging.INFO)
class TaskType:
    '''
    Execution mode of a task: train, test or infer.

    An instance stores one of the ``*_MODE`` constants in ``self.mode``;
    the ``is_*`` predicates compare the stored mode against those constants
    and the ``create_*`` factories build an instance for each mode.
    '''
    # NOTE(review): RANK / CLASSFICATION share numeric values with the
    # *_MODE constants below and look superseded by ModelType (whose values
    # are swapped: CLASSIFICATION = 0, RANK = 1) -- confirm before removing.
    # pairwise rank.
    RANK = 0
    # classification.
    CLASSFICATION = 1

    TRAIN_MODE = 0
    TEST_MODE = 1
    INFER_MODE = 2

    def __init__(self, mode):
        '''
        @mode: int
            one of TRAIN_MODE, TEST_MODE or INFER_MODE (not validated here).
        '''
        self.mode = mode

    def is_train(self):
        '''Return True when the stored mode is TRAIN_MODE.'''
        return self.mode == self.TRAIN_MODE

    def is_test(self):
        '''Return True when the stored mode is TEST_MODE.'''
        return self.mode == self.TEST_MODE

    def is_infer(self):
        '''Return True when the stored mode is INFER_MODE.'''
        return self.mode == self.INFER_MODE

    @staticmethod
    def create_train():
        '''Build a TaskType whose mode is TRAIN_MODE.'''
        return TaskType(TaskType.TRAIN_MODE)

    @staticmethod
    def create_test():
        '''Build a TaskType whose mode is TEST_MODE.'''
        return TaskType(TaskType.TEST_MODE)

    @staticmethod
    def create_infer():
        '''Build a TaskType whose mode is INFER_MODE.'''
        return TaskType(TaskType.INFER_MODE)
class ModelType:
    '''
    Kind of model to build: classification or pairwise rank.

    An instance stores one of the class constants in ``self.mode``.
    Instances compare equal to other ModelType instances holding the same
    mode and to the bare integer constant, so both
    ``model_type.is_classification()`` and
    ``model_type == ModelType.CLASSIFICATION`` give the same answer.
    '''
    CLASSIFICATION = 0
    RANK = 1

    def __init__(self, mode):
        '''
        @mode: int or ModelType
            CLASSIFICATION or RANK; an existing ModelType is unwrapped so
            that wrapping a value twice does not nest instances.
        '''
        if isinstance(mode, ModelType):
            mode = mode.mode
        self.mode = mode

    def is_classification(self):
        '''Return True when the stored mode is CLASSIFICATION.'''
        return self.mode == self.CLASSIFICATION

    def is_rank(self):
        '''Return True when the stored mode is RANK.'''
        return self.mode == self.RANK

    def __eq__(self, other):
        # Compare by mode value: without this, an instance never equals
        # the integer constants that callers test against.
        if isinstance(other, ModelType):
            return self.mode == other.mode
        return self.mode == other

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep instances hashable and consistent with __eq__.
        return hash(self.mode)

    @staticmethod
    def create_classification():
        '''Build a ModelType whose mode is CLASSIFICATION.'''
        return ModelType(ModelType.CLASSIFICATION)

    @staticmethod
    def create_rank():
        '''Build a ModelType whose mode is RANK.'''
        return ModelType(ModelType.RANK)
def sent2ids(sent, vocab):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册