From 34654503aadad1efb4d9de96eb016381c62407db Mon Sep 17 00:00:00 2001 From: Superjom Date: Sun, 9 Jul 2017 17:25:04 +0800 Subject: [PATCH] fix regression bug --- dssm/network_conf.py | 41 ++++++++++++++++++----------- dssm/reader.py | 28 +++++++++++++++++--- dssm/train.py | 18 ++++++++----- sequence_tagging_for_ner/README.md | 4 +-- sequence_tagging_for_ner/index.html | 4 +-- 5 files changed, 67 insertions(+), 28 deletions(-) diff --git a/dssm/network_conf.py b/dssm/network_conf.py index e88152f3..91607982 100644 --- a/dssm/network_conf.py +++ b/dssm/network_conf.py @@ -64,12 +64,14 @@ class DSSM(object): 'rank': self._build_rank_model, 'regression': self._build_regression_model, } + print 'model type: ', str(self.model_type) self.model_type_creater = _model_type[str(self.model_type)] def __call__(self): - if self.model_type.is_classification(): - return self._build_classification_model() - return self._build_rank_model() + # if self.model_type.is_classification(): + # return self._build_classification_model() + # return self._build_rank_model() + return self.model_type_creater() def create_embedding(self, input, prefix=''): ''' @@ -155,10 +157,14 @@ class DSSM(object): return _input_layer def _build_classification_model(self): + logger.info("build classification model") + assert self.model_type.is_classification() return self._build_classification_or_regression_model( is_classification=True) def _build_regression_model(self): + logger.info("build regression model") + assert self.model_type.is_regression() return self._build_classification_or_regression_model( is_classification=False) @@ -172,6 +178,8 @@ class DSSM(object): - right_target sentence - label, 1 if left_target should be sorted in front of right_target, otherwise 0. ''' + logger.info("build rank model") + assert self.model_type.is_rank() source = paddle.layer.data( name='source_input', type=paddle.data_type.integer_value_sequence(self.vocab_sizes[0])) @@ -221,8 +229,9 @@ class DSSM(object): - classification label ''' - # prepare inputs. - assert self.class_num + if is_classification: + # prepare inputs. + assert self.class_num source = paddle.layer.data( name='source_input', @@ -233,7 +242,7 @@ class DSSM(object): label = paddle.layer.data( name='label_input', type=paddle.data_type.integer_value(self.class_num) - if is_classification else paddle.data_type.dense_input) + if is_classification else paddle.data_type.dense_vector(1)) prefixs = '_ _'.split( ) if self.share_semantic_generator else 'left right'.split() @@ -250,15 +259,17 @@ class DSSM(object): x = self.model_arch_creater(input, prefix=prefixs[id]) semantics.append(x) - concated_vector = paddle.layer.concat(semantics) - prediction = paddle.layer.fc( - input=concated_vector, - size=self.class_num, - act=paddle.activation.Softmax()) - cost = paddle.layer.classification_cost( - input=prediction, - label=label) if is_classification else paddle.layer.mse_cost( - prediction, label) + if is_classification: + concated_vector = paddle.layer.concat(semantics) + prediction = paddle.layer.fc( + input=concated_vector, + size=self.class_num, + act=paddle.activation.Softmax()) + cost = paddle.layer.classification_cost( + input=prediction, label=label) + else: + prediction = paddle.layer.cos_sim(*semantics) + cost = paddle.layer.mse_cost(prediction, label) return cost, prediction, label diff --git a/dssm/reader.py b/dssm/reader.py index d69d88ec..8664c98d 100644 --- a/dssm/reader.py +++ b/dssm/reader.py @@ -15,9 +15,14 @@ class Dataset(object): self.source_dic = load_dic(self.source_dic_path) self.target_dic = load_dic(self.target_dic_path) - self.record_reader = self._read_classification_record \ - if self.model_type.is_classification() \ - else self._read_rank_record + _record_reader = { + ModelType.CLASSIFICATION_MODE: self._read_classification_record, + ModelType.REGRESSION_MODE: self._read_regression_record, + ModelType.RANK_MODE: self._read_rank_record, + } + + assert isinstance(model_type, ModelType) + self.record_reader = _record_reader[model_type.mode] def train(self): ''' @@ -54,6 +59,23 @@ class Dataset(object): label = int(fs[2]) return (source, target, label, ) + def _read_regression_record(self, line): + ''' + data format: + [TAB] [TAB]