Commit 6d17432e authored by: mapingshuo

add cdssm and decatt model for text_matching_on_quora

Parent: a5bfd0cd
# Text matching on the Quora question pairs dataset
## Environment Preparation
### Install Python 2
TODO
### Install fluid 0.15.0
TODO
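For reference, one possible setup (assuming Python 2.7 and pip are already available; the GPU build differs) is to install the Fluid 0.15.0 wheel from PyPI:

```shell
pip install paddlepaddle==0.15.0
```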
## Prepare Data
First, download the Quora dataset from https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing
and extract it to the ROOT_DIR $HOME/.cache/paddle/dataset.

Then run data/prepare_quora_data.sh to download the pretrained word embedding glove.840B.300d.zip:
```shell
cd data
sh prepare_quora_data.sh
```
The final dataset directory should look like:
```shell
$HOME/.cache/paddle/dataset
    |- Quora_question_pair_partition
        |- train.tsv
        |- test.tsv
        |- dev.tsv
        |- readme.txt
        |- wordvec.txt
    |- glove.840B.300d.txt
```
### Train
```shell
python train_and_evaluate.py \
    --model_name=cdssmNet \
    --config=cdssm_base
```
You should see a log like this:
```shell
net_name: cdssmNet
config {'save_dirname': 'cdssm_model', 'optimizer_type': 'adam', 'duplicate_data': False, 'train_samples_num': 384348, 'droprate_fc': 0.1, 'fc_dim': 128, 'kernel_count': 300, 'mlp_hid_dim': [128, 128], 'OOV_fill': 'uniform', 'class_dim': 2, 'epoch_num': 50, 'lr_decay': 1, 'learning_rate': 0.001, 'batch_size': 128, 'use_lod_tensor': True, 'metric_type': ['accuracy'], 'embedding_norm': False, 'emb_dim': 300, 'droprate_conv': 0.1, 'use_pretrained_word_embedding': True, 'kernel_size': 5, 'dict_dim': 40000}
Generating word dict...
('Vocab size: ', 36057)
loading word2vec from data/glove.840B.300d.txt
preparing pretrained word embedding ...
pretrained_word_embedding to be load: [[-0.086864 0.19161 0.10915 ... -0.01516 0.11108
0.2065 ]
[ 0.27204 -0.06203 -0.1884 ... 0.13015 -0.18317
0.1323 ]
[-0.20628 0.36716 -0.071933 ... 0.14271 0.50059
0.038025 ]
...
[-0.0387745 0.03030911 -0.01028247 ... -0.03096982 -0.01002833
0.04407753]
[-0.02707165 -0.04616793 -0.0260934 ... -0.00642176 0.02934359
0.02570623]
[ 0.00578131 0.0343625 -0.02623712 ... -0.04737288 0.01997969
0.04304557]]
param name: emb.w; param shape: (40000L, 300L)
param name: conv1d.w; param shape: (1500L, 300L)
param name: fc1.w; param shape: (300L, 128L)
param name: fc1.b; param shape: (128L,)
param name: fc_2.w_0; param shape: (256L, 128L)
param name: fc_2.b_0; param shape: (128L,)
param name: fc_3.w_0; param shape: (128L, 128L)
param name: fc_3.b_0; param shape: (128L,)
param name: fc_4.w_0; param shape: (128L, 2L)
param name: fc_4.b_0; param shape: (2L,)
loading pretrained word embedding to param
[Tue Oct 9 12:48:35 2018] epoch_id: -1, dev_cost: 0.796980, accuracy: 0.5
[Tue Oct 9 12:48:36 2018] epoch_id: -1, test_cost: 0.796876, accuracy: 0.5
[Tue Oct 9 12:48:36 2018] Start Training
[Tue Oct 9 12:48:44 2018] epoch_id: 0, batch_id: 0, cost: 0.878309, acc: 0.398438
[Tue Oct 9 12:48:46 2018] epoch_id: 0, batch_id: 100, cost: 0.607255, acc: 0.664062
[Tue Oct 9 12:48:48 2018] epoch_id: 0, batch_id: 200, cost: 0.521560, acc: 0.765625
[Tue Oct 9 12:48:51 2018] epoch_id: 0, batch_id: 300, cost: 0.512380, acc: 0.734375
[Tue Oct 9 12:48:54 2018] epoch_id: 0, batch_id: 400, cost: 0.522022, acc: 0.703125
[Tue Oct 9 12:48:56 2018] epoch_id: 0, batch_id: 500, cost: 0.470358, acc: 0.789062
[Tue Oct 9 12:48:58 2018] epoch_id: 0, batch_id: 600, cost: 0.561773, acc: 0.695312
[Tue Oct 9 12:49:01 2018] epoch_id: 0, batch_id: 700, cost: 0.485580, acc: 0.742188
[Tue Oct 9 12:49:03 2018] epoch_id: 0, batch_id: 800, cost: 0.493103, acc: 0.765625
[Tue Oct 9 12:49:05 2018] epoch_id: 0, batch_id: 900, cost: 0.388173, acc: 0.804688
[Tue Oct 9 12:49:08 2018] epoch_id: 0, batch_id: 1000, cost: 0.511332, acc: 0.742188
[Tue Oct 9 12:49:10 2018] epoch_id: 0, batch_id: 1100, cost: 0.488231, acc: 0.734375
[Tue Oct 9 12:49:12 2018] epoch_id: 0, batch_id: 1200, cost: 0.438371, acc: 0.781250
[Tue Oct 9 12:49:15 2018] epoch_id: 0, batch_id: 1300, cost: 0.535350, acc: 0.750000
[Tue Oct 9 12:49:17 2018] epoch_id: 0, batch_id: 1400, cost: 0.459860, acc: 0.773438
[Tue Oct 9 12:49:19 2018] epoch_id: 0, batch_id: 1500, cost: 0.382312, acc: 0.796875
[Tue Oct 9 12:49:22 2018] epoch_id: 0, batch_id: 1600, cost: 0.480827, acc: 0.742188
[Tue Oct 9 12:49:24 2018] epoch_id: 0, batch_id: 1700, cost: 0.474005, acc: 0.789062
[Tue Oct 9 12:49:26 2018] epoch_id: 0, batch_id: 1800, cost: 0.421068, acc: 0.789062
[Tue Oct 9 12:49:28 2018] epoch_id: 0, batch_id: 1900, cost: 0.420553, acc: 0.789062
[Tue Oct 9 12:49:31 2018] epoch_id: 0, batch_id: 2000, cost: 0.458412, acc: 0.781250
[Tue Oct 9 12:49:33 2018] epoch_id: 0, batch_id: 2100, cost: 0.360774, acc: 0.859375
[Tue Oct 9 12:49:35 2018] epoch_id: 0, batch_id: 2200, cost: 0.361226, acc: 0.835938
[Tue Oct 9 12:49:37 2018] epoch_id: 0, batch_id: 2300, cost: 0.371504, acc: 0.843750
[Tue Oct 9 12:49:40 2018] epoch_id: 0, batch_id: 2400, cost: 0.449930, acc: 0.804688
[Tue Oct 9 12:49:42 2018] epoch_id: 0, batch_id: 2500, cost: 0.442774, acc: 0.828125
[Tue Oct 9 12:49:44 2018] epoch_id: 0, batch_id: 2600, cost: 0.471352, acc: 0.781250
[Tue Oct 9 12:49:46 2018] epoch_id: 0, batch_id: 2700, cost: 0.344527, acc: 0.875000
[Tue Oct 9 12:49:48 2018] epoch_id: 0, batch_id: 2800, cost: 0.450750, acc: 0.765625
[Tue Oct 9 12:49:51 2018] epoch_id: 0, batch_id: 2900, cost: 0.459296, acc: 0.835938
[Tue Oct 9 12:49:53 2018] epoch_id: 0, batch_id: 3000, cost: 0.495118, acc: 0.742188
[Tue Oct 9 12:49:53 2018] epoch_id: 0, train_avg_cost: 0.457090, train_avg_acc: 0.779325
[Tue Oct 9 12:49:54 2018] epoch_id: 0, dev_cost: 0.439462, accuracy: 0.7865
[Tue Oct 9 12:49:55 2018] epoch_id: 0, test_cost: 0.441658, accuracy: 0.7867
[Tue Oct 9 12:50:04 2018] epoch_id: 1, batch_id: 0, cost: 0.320335, acc: 0.843750
[Tue Oct 9 12:50:06 2018] epoch_id: 1, batch_id: 100, cost: 0.398587, acc: 0.820312
[Tue Oct 9 12:50:08 2018] epoch_id: 1, batch_id: 200, cost: 0.324227, acc: 0.843750
[Tue Oct 9 12:50:11 2018] epoch_id: 1, batch_id: 300, cost: 0.303423, acc: 0.890625
[Tue Oct 9 12:50:13 2018] epoch_id: 1, batch_id: 400, cost: 0.438270, acc: 0.812500
[Tue Oct 9 12:50:15 2018] epoch_id: 1, batch_id: 500, cost: 0.307846, acc: 0.828125
[Tue Oct 9 12:50:19 2018] epoch_id: 1, batch_id: 600, cost: 0.338888, acc: 0.851562
[Tue Oct 9 12:50:21 2018] epoch_id: 1, batch_id: 700, cost: 0.341852, acc: 0.843750
[Tue Oct 9 12:50:23 2018] epoch_id: 1, batch_id: 800, cost: 0.365191, acc: 0.820312
[Tue Oct 9 12:50:25 2018] epoch_id: 1, batch_id: 900, cost: 0.464820, acc: 0.804688
[Tue Oct 9 12:50:28 2018] epoch_id: 1, batch_id: 1000, cost: 0.348680, acc: 0.851562
[Tue Oct 9 12:50:30 2018] epoch_id: 1, batch_id: 1100, cost: 0.390921, acc: 0.828125
[Tue Oct 9 12:50:32 2018] epoch_id: 1, batch_id: 1200, cost: 0.361488, acc: 0.820312
[Tue Oct 9 12:50:35 2018] epoch_id: 1, batch_id: 1300, cost: 0.324751, acc: 0.851562
[Tue Oct 9 12:50:37 2018] epoch_id: 1, batch_id: 1400, cost: 0.428706, acc: 0.804688
[Tue Oct 9 12:50:39 2018] epoch_id: 1, batch_id: 1500, cost: 0.504243, acc: 0.742188
[Tue Oct 9 12:50:42 2018] epoch_id: 1, batch_id: 1600, cost: 0.322159, acc: 0.851562
[Tue Oct 9 12:50:44 2018] epoch_id: 1, batch_id: 1700, cost: 0.451969, acc: 0.757812
[Tue Oct 9 12:50:46 2018] epoch_id: 1, batch_id: 1800, cost: 0.298705, acc: 0.890625
[Tue Oct 9 12:50:49 2018] epoch_id: 1, batch_id: 1900, cost: 0.439283, acc: 0.789062
[Tue Oct 9 12:50:51 2018] epoch_id: 1, batch_id: 2000, cost: 0.325409, acc: 0.851562
[Tue Oct 9 12:50:53 2018] epoch_id: 1, batch_id: 2100, cost: 0.312230, acc: 0.875000
[Tue Oct 9 12:50:56 2018] epoch_id: 1, batch_id: 2200, cost: 0.352170, acc: 0.843750
[Tue Oct 9 12:50:58 2018] epoch_id: 1, batch_id: 2300, cost: 0.366158, acc: 0.828125
[Tue Oct 9 12:51:00 2018] epoch_id: 1, batch_id: 2400, cost: 0.349191, acc: 0.812500
[Tue Oct 9 12:51:02 2018] epoch_id: 1, batch_id: 2500, cost: 0.391564, acc: 0.835938
[Tue Oct 9 12:51:05 2018] epoch_id: 1, batch_id: 2600, cost: 0.347518, acc: 0.835938
[Tue Oct 9 12:51:07 2018] epoch_id: 1, batch_id: 2700, cost: 0.279777, acc: 0.914062
[Tue Oct 9 12:51:09 2018] epoch_id: 1, batch_id: 2800, cost: 0.293878, acc: 0.851562
[Tue Oct 9 12:51:11 2018] epoch_id: 1, batch_id: 2900, cost: 0.367596, acc: 0.843750
[Tue Oct 9 12:51:13 2018] epoch_id: 1, batch_id: 3000, cost: 0.433259, acc: 0.804688
[Tue Oct 9 12:51:14 2018] epoch_id: 1, train_avg_cost: 0.348265, train_avg_acc: 0.841591
[Tue Oct 9 12:51:15 2018] epoch_id: 1, dev_cost: 0.398465, accuracy: 0.8163
[Tue Oct 9 12:51:16 2018] epoch_id: 1, test_cost: 0.399254, accuracy: 0.8209
```
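The entry script looks the model and config up by name (models.__dict__[args.model_name] and configs.__dict__[args.config]), and this commit registers cdssmNet/cdssm_base as well as DecAttNet/decatt_glove. Training the decomposable attention model should therefore look like:

```shell
python train_and_evaluate.py \
    --model_name=DecAttNet \
    --config=decatt_glove
```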
from cdssm import cdssm_base
from dec_att import decatt_glove
from __future__ import print_function
class config(object):
def __init__(self):
self.batch_size = 128
self.epoch_num = 50
self.optimizer_type = 'adam' # sgd, adagrad
# pretrained word embedding
self.use_pretrained_word_embedding = True
        # when employing a pretrained word embedding, the embeddings of
        # out-of-vocabulary words are initialized from a uniform or normal distribution
self.OOV_fill = 'uniform'
self.embedding_norm = False
        # if True, feed sequences as LoD tensors; otherwise use padding and masks for sequence data
self.use_lod_tensor = True
# lr = lr * lr_decay after each epoch
self.lr_decay = 1
self.learning_rate = 0.001
self.save_dirname = 'model_dir'
self.train_samples_num = 384348
self.duplicate_data = False
self.metric_type = ['accuracy']
def list_config(self):
print("config", self.__dict__)
if __name__ == "__main__":
basic = config()
basic.list_config()
basic.ahh = 2
basic.list_config()
import basic_config
def cdssm_base():
"""
set configs
"""
config = basic_config.config()
config.learning_rate = 0.001
config.save_dirname = "cdssm_model"
config.use_pretrained_word_embedding = True
config.dict_dim = 40000 # approx_vocab_size
# net config
config.emb_dim = 300
config.kernel_size = 5
config.kernel_count = 300
config.fc_dim = 128
config.mlp_hid_dim = [128, 128]
config.droprate_conv = 0.1
config.droprate_fc = 0.1
config.class_dim = 2
return config
import basic_config
def decatt_glove():
"""
    use the 'decAtt_glove' configuration from the paper 'Neural Paraphrase Identification of Questions with Noisy Pretraining'
"""
config = basic_config.config()
config.learning_rate = 0.05
config.save_dirname = "decatt_model"
config.use_pretrained_word_embedding = True
config.dict_dim = 40000 # approx_vocab_size
config.metric_type = ['accuracy', 'accuracy_with_threshold']
config.optimizer_type = 'sgd'
config.lr_decay = 1
config.use_lod_tensor = False
config.embedding_norm = False
config.OOV_fill = 'uniform'
config.duplicate_data = False
# net config
config.emb_dim = 300
config.proj_emb_dim = 200 #TODO: has project?
config.num_units = [400, 200]
config.word_embedding_trainable = True
config.droprate = 0.1
config.share_wight_btw_seq = True
config.class_dim = 2
return config
# Please first download the Quora dataset from https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing
# and extract it to the ROOT_DIR: $HOME/.cache/paddle/dataset
DATA_DIR=$HOME/.cache/paddle/dataset
wget --directory-prefix=$DATA_DIR http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip -d $DATA_DIR $DATA_DIR/glove.840B.300d.zip
# The final dataset dir should look like
#
# $HOME/.cache/paddle/dataset
#     |- Quora_question_pair_partition
#         |- train.tsv
#         |- test.tsv
#         |- dev.tsv
#         |- readme.txt
#         |- wordvec.txt
#     |- glove.840B.300d.txt
"""
This module provides different kinds of matching layers
"""
import paddle.fluid as fluid
import paddle.v2 as paddle
def MultiPerspectiveMatching(vec1, vec2, perspective_num):
"""
MultiPerspectiveMatching
"""
sim_res = None
for i in range(perspective_num):
vec1_res = fluid.layers.elementwise_add_with_weight(
vec1,
param_attr="elementwise_add_with_weight." + str(i))
vec2_res = fluid.layers.elementwise_add_with_weight(
vec2,
param_attr="elementwise_add_with_weight." + str(i))
m = fluid.layers.cos_sim(vec1_res, vec2_res)
if sim_res is None:
sim_res = m
else:
sim_res = fluid.layers.concat(input=[sim_res, m], axis=1)
return sim_res
def ConcateMatching(vec1, vec2):
"""
ConcateMatching
"""
#TODO: assert shape
return fluid.layers.concat(input=[vec1, vec2], axis=1)
def ElementwiseMatching(vec1, vec2):
"""
reference: [Supervised Learning of Universal Sentence Representations from Natural Language Inference Data](https://arxiv.org/abs/1705.02364)
"""
elementwise_mul = fluid.layers.elementwise_mul(x=vec1, y=vec2)
elementwise_sub = fluid.layers.elementwise_sub(x=vec1, y=vec2)
elementwise_abs_sub = fluid.layers.abs(elementwise_sub)
return fluid.layers.concat(input=[vec1, vec2, elementwise_mul, elementwise_abs_sub], axis=1)
def MultiPerspectiveFullMatching(seq1, seq2, perspective_num):
    """
    seq1: a LoD tensor with shape [-1, feature_dim] (lod_level == 1), the representation of a sentence.
    seq2: another LoD tensor with shape [-1, feature_dim] (lod_level == 1), the representation of a sentence.
    Use seq1 to match seq2 and return a match sequence aligned with seq1.
    """
    # take the last time step of seq2 and broadcast it over every time step of seq1
    seq2_last = fluid.layers.sequence_pool(input=seq2, pool_type="last")
    seq2 = fluid.layers.sequence_expand(seq2_last, seq1)
    # align seq2's LoD with seq1 so that the per-step matching below is valid
    seq2 = fluid.layers.lod_reset(x=seq2, y=seq1)
    # match every step of seq1 against the broadcast last state of seq2,
    # reusing the multi-perspective matching defined above
    return MultiPerspectiveMatching(seq1, seq2, perspective_num)
import numpy as np
"""
This module defines evaluation metrics for classification tasks
"""
def accuracy(y_pred, label):
"""
    A prediction is counted as correct when the top-1 class in y_pred equals the label.
"""
y_pred = np.squeeze(y_pred)
y_pred_idx = np.argmax(y_pred, axis=1)
return 1.0 * np.sum(y_pred_idx == label) / label.shape[0]
def accuracy_with_threshold(y_pred, label, threshold=0.5):
"""
    A prediction is counted as correct when the probability of class 1 in y_pred exceeds the threshold.
    With threshold = 0.5, this function is equivalent to accuracy for binary classification.
"""
y_pred = np.squeeze(y_pred)
y_pred_idx = (y_pred[:, 1] > threshold).astype(int)
return 1.0 * np.sum(y_pred_idx == label) / label.shape[0]
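if __name__ == '__main__':
    # A minimal, illustrative check of the two metrics above; the arrays are
    # made-up softmax outputs and labels, not data from the Quora task.
    _y_pred = np.array([[0.8, 0.2], [0.4, 0.6], [0.55, 0.45]])
    _label = np.array([0, 1, 1])
    print("accuracy: ", accuracy(_y_pred, _label))                            # 2/3: the last row's argmax is class 0
    print("accuracy@0.4: ", accuracy_with_threshold(_y_pred, _label, 0.4))    # 1.0: class-1 probs [0.2, 0.6, 0.45] vs labels [0, 1, 1]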
from cdssm import cdssmNet
from dec_att import DecAttNet
import paddle.fluid as fluid
class cdssmNet():
"""cdssm net"""
def __init__(self, config):
self._config = config
def __call__(self, seq1, seq2, label):
return self.body(seq1, seq2, label, self._config)
def body(self, seq1, seq2, label, config):
"""Body function"""
def conv_model(seq):
embed = fluid.layers.embedding(input=seq, size=[config.dict_dim, config.emb_dim], param_attr='emb.w')
conv = fluid.layers.sequence_conv(embed,
num_filters=config.kernel_count,
filter_size=config.kernel_size,
filter_stride=1,
padding=True, # TODO: what is padding
bias_attr=False,
param_attr='conv1d.w',
act='relu')
#print paddle.parameters.get('conv1d.w').shape
conv = fluid.layers.dropout(conv, dropout_prob = config.droprate_conv)
pool = fluid.layers.sequence_pool(conv, pool_type="max")
fc = fluid.layers.fc(pool,
size=config.fc_dim,
param_attr='fc1.w',
bias_attr='fc1.b',
act='relu')
return fc
def MLP(vec):
for dim in config.mlp_hid_dim:
vec = fluid.layers.fc(vec, size=dim, act='relu')
vec = fluid.layers.dropout(vec, dropout_prob=config.droprate_fc)
return vec
seq1_fc = conv_model(seq1)
seq2_fc = conv_model(seq2)
concated_seq = fluid.layers.concat(input=[seq1_fc, seq2_fc], axis=1)
mlp_res = MLP(concated_seq)
prediction = fluid.layers.fc(mlp_res, size=config.class_dim, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=loss)
acc = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, acc, prediction
import paddle.fluid as fluid
class DecAttNet():
"""decompose attention net"""
def __init__(self, config):
self._config = config
self.initializer = fluid.initializer.Xavier(uniform=False)
def __call__(self, seq1, seq2, mask1, mask2, label):
return self.body(seq1, seq2, mask1, mask2, label)
def body(self, seq1, seq2, mask1, mask2, label):
"""Body function"""
transformed_q1 = self.transformation(seq1)
transformed_q2 = self.transformation(seq2)
masked_q1 = self.apply_mask(transformed_q1, mask1)
masked_q2 = self.apply_mask(transformed_q2, mask2)
alpha, beta = self.attend(masked_q1, masked_q2)
if self._config.share_wight_btw_seq:
seq1_compare = self.compare(masked_q1, beta, param_prefix='compare')
seq2_compare = self.compare(masked_q2, alpha, param_prefix='compare')
else:
seq1_compare = self.compare(masked_q1, beta, param_prefix='compare_1')
seq2_compare = self.compare(masked_q2, alpha, param_prefix='compare_2')
aggregate_res = self.aggregate(seq1_compare, seq2_compare)
prediction = fluid.layers.fc(aggregate_res, size=self._config.class_dim, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=loss)
acc = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, acc, prediction
def apply_mask(self, seq, mask):
"""
apply mask on seq
Input: seq in shape [batch_size, seq_len, embedding_size]
Input: mask in shape [batch_size, seq_len]
Output: masked seq in shape [batch_size, seq_len, embedding_size]
"""
return fluid.layers.elementwise_mul(x=seq, y=mask, axis=0)
def feed_forward_2d(self, vec, param_prefix):
"""
Input: vec in shape [batch_size, seq_len, vec_dim]
Output: fc2 in shape [batch_size, seq_len, num_units[1]]
"""
fc1 = fluid.layers.fc(vec, size=self._config.num_units[0], num_flatten_dims=2,
param_attr=fluid.ParamAttr(name=param_prefix+'_fc1.w',
initializer=self.initializer),
bias_attr=param_prefix + '_fc1.b', act='relu')
fc1 = fluid.layers.dropout(fc1, dropout_prob = self._config.droprate)
fc2 = fluid.layers.fc(fc1, size=self._config.num_units[1], num_flatten_dims=2,
param_attr=fluid.ParamAttr(name=param_prefix+'_fc2.w',
initializer=self.initializer),
bias_attr=param_prefix + '_fc2.b', act='relu')
fc2 = fluid.layers.dropout(fc2, dropout_prob = self._config.droprate)
return fc2
def feed_forward(self, vec, param_prefix):
"""
Input: vec in shape [batch_size, vec_dim]
Output: fc2 in shape [batch_size, num_units[1]]
"""
fc1 = fluid.layers.fc(vec, size=self._config.num_units[0], num_flatten_dims=1,
param_attr=fluid.ParamAttr(name=param_prefix+'_fc1.w',
initializer=self.initializer),
bias_attr=param_prefix + '_fc1.b', act='relu')
fc1 = fluid.layers.dropout(fc1, dropout_prob = self._config.droprate)
fc2 = fluid.layers.fc(fc1, size=self._config.num_units[1], num_flatten_dims=1,
param_attr=fluid.ParamAttr(name=param_prefix+'_fc2.w',
initializer=self.initializer),
bias_attr=param_prefix + '_fc2.b', act='relu')
fc2 = fluid.layers.dropout(fc2, dropout_prob = self._config.droprate)
return fc2
def transformation(self, seq):
embed = fluid.layers.embedding(input=seq, size=[self._config.dict_dim, self._config.emb_dim],
param_attr=fluid.ParamAttr(name='emb.w', trainable=self._config.word_embedding_trainable))
if self._config.proj_emb_dim is not None:
return fluid.layers.fc(embed, size=self._config.proj_emb_dim, num_flatten_dims=2,
param_attr=fluid.ParamAttr(name='project' + '_fc1.w',
initializer=self.initializer),
bias_attr=False,
act=None)
return embed
def attend(self, seq1, seq2):
"""
Input: seq1, shape [batch_size, seq_len1, embed_size]
Input: seq2, shape [batch_size, seq_len2, embed_size]
Output: alpha, shape [batch_size, seq_len1, embed_size]
Output: beta, shape [batch_size, seq_len2, embed_size]
"""
if self._config.share_wight_btw_seq:
seq1 = self.feed_forward_2d(seq1, param_prefix="attend")
seq2 = self.feed_forward_2d(seq2, param_prefix="attend")
else:
seq1 = self.feed_forward_2d(seq1, param_prefix="attend_1")
seq2 = self.feed_forward_2d(seq2, param_prefix="attend_2")
attention_weight = fluid.layers.matmul(seq1, seq2, transpose_y=True)
normalized_attention_weight = fluid.layers.softmax(attention_weight)
beta = fluid.layers.matmul(normalized_attention_weight, seq2)
attention_weight_t = fluid.layers.transpose(attention_weight, perm=[0, 2, 1])
normalized_attention_weight_t = fluid.layers.softmax(attention_weight_t)
alpha = fluid.layers.matmul(normalized_attention_weight_t, seq1)
return alpha, beta
    def compare(self, seq, soft_alignment, param_prefix):
        concat_seq = fluid.layers.concat(input=[seq, soft_alignment], axis=2)
        return self.feed_forward_2d(concat_seq, param_prefix=param_prefix)
def aggregate(self, vec1, vec2):
vec1 = fluid.layers.reduce_sum(vec1, dim=1)
vec2 = fluid.layers.reduce_sum(vec2, dim=1)
concat_vec = fluid.layers.concat(input=[vec1, vec2], axis=1)
return self.feed_forward(concat_vec, param_prefix='aggregate')
"""
This module defines some frequently-used DNN layers
"""
import paddle.fluid as fluid
def bi_lstm_layer(input, rnn_hid_dim, name):
"""
    This is a bi-directional LSTM (long short-term memory) module
"""
fc0 = fluid.layers.fc(input=input, # fc for lstm
size=rnn_hid_dim * 4,
param_attr=name + '.fc0.w',
bias_attr=False,
act=None)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0,
size=rnn_hid_dim * 4,
is_reverse=False,
param_attr=name + '.lstm_w',
bias_attr=name + '.lstm_b')
reversed_lstm_h, reversed_c = fluid.layers.dynamic_lstm(
input=fc0,
size=rnn_hid_dim * 4,
is_reverse=True,
param_attr=name + '.reversed_lstm_w',
bias_attr=name + '.reversed_lstm_b')
return fluid.layers.concat(input=[lstm_h, reversed_lstm_h], axis=1)
"""
This module provides pretrained word embeddings
"""
from __future__ import print_function
import numpy as np
def Glove840B_300D(filepath="data/glove.840B.300d.txt"):
"""
input: the "glove.840B.300d.txt" file path
return: a dict, key: word (unicode), value: a numpy array with shape [300]
"""
print("loading word2vec from ", filepath)
word2vec = {}
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
info = line.strip().split()
word, vector = info[0], info[1:]
assert(len(vector) == 300)
#TODO: test python3
word2vec[word.decode('utf-8')] = np.asarray(vector, dtype='float32')
return word2vec
if __name__ == '__main__':
embed_dict = Glove840B_300D("data/glove.840B.300d.txt")
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
import paddle.dataset.common
import collections
import tarfile
import re
import string
import random
import os
import nltk
from os.path import expanduser
__all__ = ['word_dict', 'train', 'dev', 'test']
URL = "https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view"
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
DATA_DIR = "Quora_question_pair_partition"
QUORA_TRAIN_FILE_NAME = os.path.join(DATA_HOME, DATA_DIR, 'train.tsv')
QUORA_DEV_FILE_NAME = os.path.join(DATA_HOME, DATA_DIR, 'dev.tsv')
QUORA_TEST_FILE_NAME = os.path.join(DATA_HOME, DATA_DIR, 'test.tsv')
# punctuation or nltk or space
TOKENIZE_METHOD='space'
COLUMN_COUNT = 4
def tokenize(s):
s = s.decode('utf-8')
if TOKENIZE_METHOD == "nltk":
return nltk.tokenize.word_tokenize(s)
elif TOKENIZE_METHOD == "punctuation":
return s.translate({ord(char): None for char in string.punctuation}).lower().split()
elif TOKENIZE_METHOD == "space":
return s.split()
else:
raise RuntimeError("Invalid tokenize method")
def maybe_open(file_name):
if not os.path.isfile(file_name):
msg = "file not exist: %s\nPlease download the dataset firstly from: %s\n\n" % (file_name, URL) + \
("# The finally dataset dir should be like\n\n"
"$HOME/.cache/paddle/dataset\n"
" |- Quora_question_pair_partition\n"
" |- train.tsv\n"
" |- test.tsv\n"
" |- dev.tsv\n"
" |- readme.txt\n"
" |- wordvec.txt\n")
raise RuntimeError(msg)
return open(file_name, 'r')
def tokenized_question_pairs(file_name):
"""
"""
with maybe_open(file_name) as f:
questions = {}
lines = f.readlines()
for line in lines:
info = line.strip().split('\t')
if len(info) != COLUMN_COUNT:
# formatting error
continue
(label, question1, question2, id) = info
question1 = tokenize(question1)
question2 = tokenize(question2)
yield question1, question2, int(label)
def tokenized_questions(file_name):
"""
"""
with maybe_open(file_name) as f:
lines = f.readlines()
for line in lines:
info = line.strip().split('\t')
if len(info) != COLUMN_COUNT:
# formatting error
continue
(label, question1, question2, id) = info
yield tokenize(question1)
yield tokenize(question2)
def build_dict(file_name, cutoff):
"""
Build a word dictionary from the corpus. Keys of the dictionary are words,
and values are zero-based IDs of these words.
"""
word_freq = collections.defaultdict(int)
for doc in tokenized_questions(file_name):
for word in doc:
word_freq[word] += 1
word_freq = filter(lambda x: x[1] > cutoff, word_freq.items())
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)
word_idx['<pad>'] = len(words) + 1
return word_idx
def reader_creator(file_name, word_idx):
UNK_ID = word_idx['<unk>']
def reader():
for (q1, q2, label) in tokenized_question_pairs(file_name):
q1_ids = [word_idx.get(w, UNK_ID) for w in q1]
q2_ids = [word_idx.get(w, UNK_ID) for w in q2]
if q1_ids != [] and q2_ids != []: # [] is not allowed in fluid
assert(label in [0, 1])
yield q1_ids, q2_ids, label
return reader
def train(word_idx):
"""
Quora training set creator.
    It returns a reader creator; each sample in the reader is a pair of zero-based
    word ID lists and a label in [0, 1].
:param word_idx: word dictionary
:type word_idx: dict
:return: Training reader creator
:rtype: callable
"""
return reader_creator(QUORA_TRAIN_FILE_NAME, word_idx)
def dev(word_idx):
"""
    Quora development set creator.
    It returns a reader creator; each sample in the reader is a pair of zero-based
    word ID lists and a label in [0, 1].
    :param word_idx: word dictionary
    :type word_idx: dict
    :return: development set reader creator
:rtype: callable
"""
return reader_creator(QUORA_DEV_FILE_NAME, word_idx)
def test(word_idx):
"""
Quora test set creator.
    It returns a reader creator; each sample in the reader is a pair of zero-based
    word ID lists and a label in [0, 1].
:param word_idx: word dictionary
:type word_idx: dict
:return: Test reader creator
:rtype: callable
"""
return reader_creator(QUORA_TEST_FILE_NAME, word_idx)
def word_dict():
"""
Build a word dictionary from the corpus.
:return: Word dictionary
:rtype: dict
"""
return build_dict(file_name=QUORA_TRAIN_FILE_NAME, cutoff=4)
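if __name__ == '__main__':
    # Illustrative usage of this reader module; it assumes the dataset files are
    # already placed under $HOME/.cache/paddle/dataset as described above.
    vocab = word_dict()
    print("vocab size: ", len(vocab))
    # train(vocab) returns a reader; calling it yields (q1_ids, q2_ids, label) tuples
    for i, sample in enumerate(train(vocab)()):
        print(sample)
        if i >= 2:
            break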
from __future__ import print_function
import os
import sys
import time
import argparse
import unittest
import contextlib
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
import utils, metric, configs
import models
from pretrained_word2vec import Glove840B_300D
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--model_name', type=str, default='cdssm', help="Which model to train")
parser.add_argument('--config', type=str, default='cdssm.cdssm_base', help="The global config setting")
DATA_DIR = os.path.join(os.path.expanduser('~'), '.cache/paddle/dataset')
def evaluate(epoch_id, exe, inference_program, dev_reader, test_reader, fetch_list, feeder, metric_type):
"""
evaluate on test/dev dataset
"""
def infer(test_reader):
"""
        run inference on the given reader and return the average cost and metric results
"""
total_cost = 0.0
total_count = 0
preds, labels = [], []
for data in test_reader():
avg_cost, avg_acc, batch_prediction = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_list,
return_numpy=True)
total_cost += avg_cost * len(data)
total_count += len(data)
preds.append(batch_prediction)
labels.append(np.asarray([x[-1] for x in data], dtype=np.int64))
y_pred = np.concatenate(preds)
y_label = np.concatenate(labels)
metric_res = []
for metric_name in metric_type:
if metric_name == 'accuracy_with_threshold':
metric_res.append((metric_name, metric.accuracy_with_threshold(y_pred, y_label, threshold=0.3)))
elif metric_name == 'accuracy':
metric_res.append((metric_name, metric.accuracy(y_pred, y_label)))
else:
print("Unknown metric type: ", metric_name)
exit()
return total_cost / (total_count * 1.0), metric_res
dev_cost, dev_metric_res = infer(dev_reader)
print("[%s] epoch_id: %d, dev_cost: %f, " % (
time.asctime( time.localtime(time.time()) ),
epoch_id,
dev_cost)
+ ', '.join([str(x[0]) + ": " + str(x[1]) for x in dev_metric_res]))
test_cost, test_metric_res = infer(test_reader)
print("[%s] epoch_id: %d, test_cost: %f, " % (
time.asctime( time.localtime(time.time()) ),
epoch_id,
test_cost)
+ ', '.join([str(x[0]) + ": " + str(x[1]) for x in test_metric_res]))
print("")
def train_and_evaluate(train_reader,
test_reader,
dev_reader,
network,
optimizer,
global_config,
pretrained_word_embedding,
use_cuda,
parallel):
"""
train network
"""
# define the net
if global_config.use_lod_tensor:
        # the batch dimension is added automatically
q1 = fluid.layers.data(
name="question1", shape=[1], dtype="int64", lod_level=1)
q2 = fluid.layers.data(
name="question2", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost, acc, prediction = network(q1, q2, label)
else:
# shape: [batch_size, max_seq_len_in_batch, 1]
q1 = fluid.layers.data(
name="question1", shape=[-1, -1, 1], dtype="int64")
q2 = fluid.layers.data(
name="question2", shape=[-1, -1, 1], dtype="int64")
# shape: [batch_size, max_seq_len_in_batch]
mask1 = fluid.layers.data(name="mask1", shape=[-1, -1], dtype="float32")
mask2 = fluid.layers.data(name="mask2", shape=[-1, -1], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost, acc, prediction = network(q1, q2, mask1, mask2, label)
if parallel:
        # TODO: Parallel Training
print("Parallel Training is not supported for now.")
sys.exit(1)
optimizer.minimize(cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
if global_config.use_lod_tensor:
feeder = fluid.DataFeeder(feed_list=[q1, q2, label], place=place)
else:
feeder = fluid.DataFeeder(feed_list=[q1, q2, mask1, mask2, label], place=place)
# logging param info
for param in fluid.default_main_program().global_block().all_parameters():
print("param name: %s; param shape: %s" % (param.name, param.shape))
# define inference_program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program([cost, acc])
exe.run(fluid.default_startup_program())
    # load the embedding weights from a numpy array
if pretrained_word_embedding is not None:
print("loading pretrained word embedding to param")
embedding_name = "emb.w"
embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
embedding_param.set(pretrained_word_embedding, place)
evaluate(-1,
exe,
inference_program,
dev_reader,
test_reader,
fetch_list=[cost, acc, prediction],
feeder=feeder,
metric_type=global_config.metric_type)
# start training
print("[%s] Start Training" % time.asctime(time.localtime(time.time())))
for epoch_id in xrange(global_config.epoch_num):
data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
batch_id = 0
for data in train_reader():
avg_cost_np, avg_acc_np = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[cost, acc])
data_size = len(data)
total_acc += data_size * avg_acc_np
total_cost += data_size * avg_cost_np
data_count += data_size
if batch_id % 100 == 0:
print("[%s] epoch_id: %d, batch_id: %d, cost: %f, acc: %f" % (
time.asctime(time.localtime(time.time())),
epoch_id,
batch_id,
avg_cost_np,
avg_acc_np))
batch_id += 1
avg_cost = total_cost / data_count
avg_acc = total_acc / data_count
print("")
print("[%s] epoch_id: %d, train_avg_cost: %f, train_avg_acc: %f" % (
time.asctime( time.localtime(time.time()) ), epoch_id, avg_cost, avg_acc))
epoch_model = global_config.save_dirname + "/" + "epoch" + str(epoch_id)
fluid.io.save_inference_model(epoch_model, ["question1", "question2", "label"], acc, exe)
evaluate(epoch_id,
exe,
inference_program,
dev_reader,
test_reader,
fetch_list=[cost, acc, prediction],
feeder=feeder,
metric_type=global_config.metric_type)
def main():
"""
    This function parses the arguments, prepares the data and the pretrained word embedding
"""
args = parser.parse_args()
global_config = configs.__dict__[args.config]()
print("net_name: ", args.model_name)
net = models.__dict__[args.model_name](global_config)
global_config.list_config()
# get word_dict
word_dict = utils.getDict(data_type="quora_question_pairs")
# get reader
train_reader, dev_reader, test_reader = utils.prepare_data(
"quora_question_pairs",
word_dict=word_dict,
batch_size = global_config.batch_size,
buf_size=800000,
duplicate_data=global_config.duplicate_data,
use_pad=(not global_config.use_lod_tensor))
# load pretrained_word_embedding
if global_config.use_pretrained_word_embedding:
word2vec = Glove840B_300D(filepath=os.path.join(DATA_DIR, "glove.840B.300d.txt"))
pretrained_word_embedding = utils.get_pretrained_word_embedding(
word2vec=word2vec,
word2id=word_dict,
config=global_config)
print("pretrained_word_embedding to be load:", pretrained_word_embedding)
else:
pretrained_word_embedding = None
# define optimizer
optimizer = utils.getOptimizer(global_config)
    train_and_evaluate(
        train_reader,
        test_reader,
        dev_reader,
        net,
        optimizer,
        global_config,
        pretrained_word_embedding,
        use_cuda=True,
        parallel=False)
if __name__ == "__main__":
main()
"""
This module provides utilities for data generator and optimizer definition
"""
import sys
import time
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
import quora_question_pairs
def to_lodtensor(data, place):
"""
    convert a batch of sequences into a LoDTensor
"""
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
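# For example (illustrative): to_lodtensor([[1, 2], [3, 4, 5]], place) flattens the two
# sequences into a [5, 1] int64 tensor and attaches lod [[0, 2, 5]], the cumulative
# offsets marking where each sequence starts and ends.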
def getOptimizer(global_config):
"""
get Optimizer by config
"""
if global_config.optimizer_type == "adam":
optimizer = fluid.optimizer.Adam(learning_rate=fluid.layers.exponential_decay(
learning_rate=global_config.learning_rate,
decay_steps=global_config.train_samples_num // global_config.batch_size,
decay_rate=global_config.lr_decay))
elif global_config.optimizer_type == "sgd":
optimizer = fluid.optimizer.SGD(learning_rate=fluid.layers.exponential_decay(
learning_rate=global_config.learning_rate,
decay_steps=global_config.train_samples_num // global_config.batch_size,
decay_rate=global_config.lr_decay))
    elif global_config.optimizer_type == "adagrad":
        optimizer = fluid.optimizer.Adagrad(learning_rate=fluid.layers.exponential_decay(
            learning_rate=global_config.learning_rate,
            decay_steps=global_config.train_samples_num // global_config.batch_size,
            decay_rate=global_config.lr_decay))
    else:
        raise RuntimeError("Unsupported optimizer type: %s" % global_config.optimizer_type)
    return optimizer
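# Note: decay_steps equals the number of batches in one epoch (train_samples_num // batch_size,
# e.g. 384348 // 128 = 3002 for the Quora training split), so the exponential decay above
# realizes the "lr = lr * lr_decay after each epoch" rule described in basic_config.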
def get_pretrained_word_embedding(word2vec, word2id, config):
"""get pretrained embedding in shape [config.dict_dim, config.emb_dim]"""
print("preparing pretrained word embedding ...")
assert(config.dict_dim >= len(word2id))
word2id = sorted(word2id.items(), key = lambda x : x[1])
words = [x[0] for x in word2id]
words = words + ['<not-a-real-words>'] * (config.dict_dim - len(words))
pretrained_emb = []
for _, word in enumerate(words):
if word in word2vec:
            assert(len(word2vec[word]) == config.emb_dim)
if config.embedding_norm:
pretrained_emb.append(word2vec[word] / np.linalg.norm(word2vec[word]))
else:
pretrained_emb.append(word2vec[word])
elif config.OOV_fill == 'uniform':
pretrained_emb.append(np.random.uniform(-0.05, 0.05, size=[config.emb_dim]).astype(np.float32))
elif config.OOV_fill == 'normal':
pretrained_emb.append(np.random.normal(loc=0.0, scale=0.1, size=[config.emb_dim]).astype(np.float32))
else:
print("Unkown OOV fill method: ", OOV_fill)
exit()
word_embedding = np.stack(pretrained_emb)
return word_embedding
def getDict(data_type="quora_question_pairs"):
"""
get word2id dict from quora dataset
"""
print("Generating word dict...")
if data_type == "quora_question_pairs":
word_dict = quora_question_pairs.word_dict()
else:
raise RuntimeError("No such dataset")
print("Vocab size: ", len(word_dict))
return word_dict
def duplicate(reader):
"""
    duplicate the Quora question pairs since there are two questions in each sample
Input: reader, which yield (question1, question2, label)
Output: reader, which yield (question1, question2, label) and yield (question2, question1, label)
"""
def duplicated_reader():
for data in reader():
(q1, q2, label) = data
yield (q1, q2, label)
yield (q2, q1, label)
return duplicated_reader
def pad(reader, PAD_ID):
"""
Input: reader, yield batches of [(question1, question2, label), ... ]
Output: padded_reader, yield batches of [(padded_question1, padded_question2, mask1, mask2, label), ... ]
"""
assert(isinstance(PAD_ID, int))
def padded_reader():
for batch in reader():
max_len1 = max([len(data[0]) for data in batch])
max_len2 = max([len(data[1]) for data in batch])
padded_batch = []
for data in batch:
question1, question2, label = data
seq_len1 = len(question1)
seq_len2 = len(question2)
mask1 = [1] * seq_len1 + [0] * (max_len1 - seq_len1)
mask2 = [1] * seq_len2 + [0] * (max_len2 - seq_len2)
padded_question1 = question1 + [PAD_ID] * (max_len1 - seq_len1)
padded_question2 = question2 + [PAD_ID] * (max_len2 - seq_len2)
                padded_question1 = [[x] for x in padded_question1] # the last dim of a question must be 1, as required by fluid
padded_question2 = [[x] for x in padded_question2]
assert(len(mask1) == max_len1)
assert(len(mask2) == max_len2)
assert(len(padded_question1) == max_len1)
assert(len(padded_question2) == max_len2)
padded_batch.append((padded_question1, padded_question2, mask1, mask2, label))
yield padded_batch
return padded_reader
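# Illustrative example (assuming PAD_ID == 0): for an input batch
#   [([5, 3], [7, 8, 9], 1), ([4], [6, 2], 0)]
# max_len1 is 2 and max_len2 is 3, so the padded reader yields the batch
#   [([[5], [3]], [[7], [8], [9]], [1, 1], [1, 1, 1], 1),
#    ([[4], [0]], [[6], [2], [0]], [1, 0], [1, 1, 0], 0)]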
def prepare_data(data_type,
word_dict,
batch_size,
buf_size=50000,
duplicate_data=False,
use_pad=False):
"""
prepare data
"""
PAD_ID=word_dict['<pad>']
if data_type == "quora_question_pairs":
        # train/dev/test readers are batched iterators which yield a batch of (question1, question2, label) samples each time
        # question1 and question2 are lists of word IDs
        # label is 0 or 1
        # for example: ([1, 3, 2], [7, 5, 4, 99], 1)
def prepare_reader(reader):
if duplicate_data:
reader = duplicate(reader)
reader = paddle.batch(
paddle.reader.shuffle(reader, buf_size=buf_size),
batch_size=batch_size,
drop_last=False)
if use_pad:
reader = pad(reader, PAD_ID=PAD_ID)
return reader
train_reader = prepare_reader(quora_question_pairs.train(word_dict))
dev_reader = prepare_reader(quora_question_pairs.dev(word_dict))
test_reader = prepare_reader(quora_question_pairs.test(word_dict))
else:
raise RuntimeError("no such dataset")
return train_reader, dev_reader, test_reader