提交 34654503 编写于 作者: S Superjom

fix regression bug

上级 b645b46b
...@@ -64,12 +64,14 @@ class DSSM(object): ...@@ -64,12 +64,14 @@ class DSSM(object):
'rank': self._build_rank_model, 'rank': self._build_rank_model,
'regression': self._build_regression_model, 'regression': self._build_regression_model,
} }
print 'model type: ', str(self.model_type)
self.model_type_creater = _model_type[str(self.model_type)] self.model_type_creater = _model_type[str(self.model_type)]
def __call__(self): def __call__(self):
if self.model_type.is_classification(): # if self.model_type.is_classification():
return self._build_classification_model() # return self._build_classification_model()
return self._build_rank_model() # return self._build_rank_model()
return self.model_type_creater()
def create_embedding(self, input, prefix=''): def create_embedding(self, input, prefix=''):
''' '''
...@@ -155,10 +157,14 @@ class DSSM(object): ...@@ -155,10 +157,14 @@ class DSSM(object):
return _input_layer return _input_layer
def _build_classification_model(self): def _build_classification_model(self):
logger.info("build classification model")
assert self.model_type.is_classification()
return self._build_classification_or_regression_model( return self._build_classification_or_regression_model(
is_classification=True) is_classification=True)
def _build_regression_model(self): def _build_regression_model(self):
logger.info("build regression model")
assert self.model_type.is_regression()
return self._build_classification_or_regression_model( return self._build_classification_or_regression_model(
is_classification=False) is_classification=False)
...@@ -172,6 +178,8 @@ class DSSM(object): ...@@ -172,6 +178,8 @@ class DSSM(object):
- right_target sentence - right_target sentence
- label, 1 if left_target should be sorted in front of right_target, otherwise 0. - label, 1 if left_target should be sorted in front of right_target, otherwise 0.
''' '''
logger.info("build rank model")
assert self.model_type.is_rank()
source = paddle.layer.data( source = paddle.layer.data(
name='source_input', name='source_input',
type=paddle.data_type.integer_value_sequence(self.vocab_sizes[0])) type=paddle.data_type.integer_value_sequence(self.vocab_sizes[0]))
...@@ -221,8 +229,9 @@ class DSSM(object): ...@@ -221,8 +229,9 @@ class DSSM(object):
- classification label - classification label
''' '''
# prepare inputs. if is_classification:
assert self.class_num # prepare inputs.
assert self.class_num
source = paddle.layer.data( source = paddle.layer.data(
name='source_input', name='source_input',
...@@ -233,7 +242,7 @@ class DSSM(object): ...@@ -233,7 +242,7 @@ class DSSM(object):
label = paddle.layer.data( label = paddle.layer.data(
name='label_input', name='label_input',
type=paddle.data_type.integer_value(self.class_num) type=paddle.data_type.integer_value(self.class_num)
if is_classification else paddle.data_type.dense_input) if is_classification else paddle.data_type.dense_vector(1))
prefixs = '_ _'.split( prefixs = '_ _'.split(
) if self.share_semantic_generator else 'left right'.split() ) if self.share_semantic_generator else 'left right'.split()
...@@ -250,15 +259,17 @@ class DSSM(object): ...@@ -250,15 +259,17 @@ class DSSM(object):
x = self.model_arch_creater(input, prefix=prefixs[id]) x = self.model_arch_creater(input, prefix=prefixs[id])
semantics.append(x) semantics.append(x)
concated_vector = paddle.layer.concat(semantics) if is_classification:
prediction = paddle.layer.fc( concated_vector = paddle.layer.concat(semantics)
input=concated_vector, prediction = paddle.layer.fc(
size=self.class_num, input=concated_vector,
act=paddle.activation.Softmax()) size=self.class_num,
cost = paddle.layer.classification_cost( act=paddle.activation.Softmax())
input=prediction, cost = paddle.layer.classification_cost(
label=label) if is_classification else paddle.layer.mse_cost( input=prediction, label=label)
prediction, label) else:
prediction = paddle.layer.cos_sim(*semantics)
cost = paddle.layer.mse_cost(prediction, label)
return cost, prediction, label return cost, prediction, label
......
...@@ -15,9 +15,14 @@ class Dataset(object): ...@@ -15,9 +15,14 @@ class Dataset(object):
self.source_dic = load_dic(self.source_dic_path) self.source_dic = load_dic(self.source_dic_path)
self.target_dic = load_dic(self.target_dic_path) self.target_dic = load_dic(self.target_dic_path)
self.record_reader = self._read_classification_record \ _record_reader = {
if self.model_type.is_classification() \ ModelType.CLASSIFICATION_MODE: self._read_classification_record,
else self._read_rank_record ModelType.REGRESSION_MODE: self._read_regression_record,
ModelType.RANK_MODE: self._read_rank_record,
}
assert isinstance(model_type, ModelType)
self.record_reader = _record_reader[model_type.mode]
def train(self): def train(self):
''' '''
...@@ -54,6 +59,23 @@ class Dataset(object): ...@@ -54,6 +59,23 @@ class Dataset(object):
label = int(fs[2]) label = int(fs[2])
return (source, target, label, ) return (source, target, label, )
def _read_regression_record(self, line):
'''
data format:
<source words> [TAB] <target words> [TAB] <label>
@line: str
a string line which represent a record.
'''
fs = line.strip().split('\t')
assert len(fs) == 3, "wrong format for regression\n" + \
"the format shoud be " +\
"<source words> [TAB] <target words> [TAB] <label>'"
source = sent2ids(fs[0], self.source_dic)
target = sent2ids(fs[1], self.target_dic)
label = float(fs[2])
return (source, target, [label], )
def _read_rank_record(self, line): def _read_rank_record(self, line):
''' '''
data format: data format:
......
...@@ -52,8 +52,9 @@ parser.add_argument( ...@@ -52,8 +52,9 @@ parser.add_argument(
type=int, type=int,
required=True, required=True,
default=ModelType.CLASSIFICATION_MODE, default=ModelType.CLASSIFICATION_MODE,
help="model type, %d for classification, %d for pairwise rank (default: classification)" help="model type, %d for classification, %d for pairwise rank, %d for regression (default: classification)"
% (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE)) % (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
ModelType.REGRESSION_MODE))
parser.add_argument( parser.add_argument(
'--model_arch', '--model_arch',
type=int, type=int,
...@@ -124,7 +125,7 @@ def train(train_data_path=None, ...@@ -124,7 +125,7 @@ def train(train_data_path=None,
default_train_path = './data/rank/train.txt' default_train_path = './data/rank/train.txt'
default_test_path = './data/rank/test.txt' default_test_path = './data/rank/test.txt'
default_dic_path = './data/vocab.txt' default_dic_path = './data/vocab.txt'
if model_type.is_classification(): if not model_type.is_rank():
default_train_path = './data/classification/train.txt' default_train_path = './data/classification/train.txt'
default_test_path = './data/classification/test.txt' default_test_path = './data/classification/test.txt'
...@@ -173,13 +174,18 @@ def train(train_data_path=None, ...@@ -173,13 +174,18 @@ def train(train_data_path=None,
trainer = paddle.trainer.SGD( trainer = paddle.trainer.SGD(
cost=cost, cost=cost,
extra_layers=paddle.evaluator.auc(input=prediction, label=label) extra_layers=None,
if prediction else None,
parameters=parameters, parameters=parameters,
update_equation=adam_optimizer) update_equation=adam_optimizer)
# trainer = paddle.trainer.SGD(
# cost=cost,
# extra_layers=paddle.evaluator.auc(input=prediction, label=label)
# if prediction and model_type.is_classification() else None,
# parameters=parameters,
# update_equation=adam_optimizer)
feeding = {} feeding = {}
if model_type.is_classification(): if model_type.is_classification() or model_type.is_regression():
feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2} feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
else: else:
feeding = { feeding = {
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://book.paddlepaddle.org/07.label_semantic_roles/)定义的标签集,如下是一个NER的标注结果示例: 序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://book.paddlepaddle.org/07.label_semantic_roles/)定义的标签集,如下是一个NER的标注结果示例:
<div align="center"> <p align="center">
<img src="images/ner_label_ins.png" width = "80%" align=center /><br> <img src="images/ner_label_ins.png" width = "80%" align=center /><br>
图1. BIO标注方法示例 图1. BIO标注方法示例
</div> </p>
根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。 根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。
......
...@@ -65,10 +65,10 @@ ...@@ -65,10 +65,10 @@
序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://book.paddlepaddle.org/07.label_semantic_roles/)定义的标签集,如下是一个NER的标注结果示例: 序列标注可以分为Sequence Classification、Segment Classification和Temporal Classification三类[[1](#参考文献)],本例只考虑Segment Classification,即对输入序列中的每个元素在输出序列中给出对应的标签。对于NER任务,由于需要标识边界,一般采用[BIO标注方法](http://book.paddlepaddle.org/07.label_semantic_roles/)定义的标签集,如下是一个NER的标注结果示例:
<div align="center"> <p align="center">
<img src="images/ner_label_ins.png" width = "80%" align=center /><br> <img src="images/ner_label_ins.png" width = "80%" align=center /><br>
图1. BIO标注方法示例 图1. BIO标注方法示例
</div> </p>
根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。 根据序列标注结果可以直接得到实体边界和实体类别。类似的,分词、词性标注、语块识别、[语义角色标注](http://book.paddlepaddle.org/07.label_semantic_roles/index.cn.html)等任务都可通过序列标注来解决。使用神经网络模型解决问题的思路通常是:前层网络学习输入的特征表示,网络的最后一层在特征基础上完成最终的任务;对于序列标注问题,通常:使用基于RNN的网络结构学习特征,将学习到的特征接入CRF完成序列标注。实际上是将传统CRF中的线性模型换成了非线性神经网络。沿用CRF的出发点是:CRF使用句子级别的似然概率,能够更好的解决标记偏置问题[[2](#参考文献)]。本例也将基于此思路建立模型。虽然,这里以NER任务作为示例,但所给出的模型可以应用到其他各种序列标注任务中。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册