提交 e19f4bc7 编写于 作者: D Dilyar 提交者: Yibing Liu

Fix some problems of simnet (#3433)

* update

* update

* Update README.md

* Update run.sh
上级 107d4e79
......@@ -49,7 +49,7 @@ class BOW(object):
right_soft = softsign_layer.ops(right_pool)
# matching layer
if self.task_mode == "pairwise":
bow_layer = layers.FCLayer(self.bow_dim, "relu", "fc")
bow_layer = layers.FCLayer(self.bow_dim, None, "fc")
left_bow = bow_layer.ops(left_soft)
right_bow = bow_layer.ops(right_soft)
cos_sim_layer = layers.CosSimLayer()
......@@ -58,7 +58,7 @@ class BOW(object):
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_soft, right_soft])
bow_layer = layers.FCLayer(self.bow_dim, "relu", "fc")
bow_layer = layers.FCLayer(self.bow_dim, None, "fc")
concat_fc = bow_layer.ops(concat)
softmax_layer = layers.FCLayer(2, "softmax", "cos_sim")
pred = softmax_layer.ops(concat_fc)
......
......@@ -43,23 +43,23 @@ class CNN(object):
left_emb = emb_layer.ops(left)
right_emb = emb_layer.ops(right)
# Presentation context
cnn_layer = layers.SequenceConvPoolLayer(self.filter_size,
self.num_filters, "conv")
cnn_layer = layers.SequenceConvPoolLayer(
self.filter_size, self.num_filters, "conv")
left_cnn = cnn_layer.ops(left_emb)
right_cnn = cnn_layer.ops(right_emb)
# matching layer
if self.task_mode == "pairwise":
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
left_relu = relu_layer.ops(left_cnn)
right_relu = relu_layer.ops(right_cnn)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
left_fc = fc_layer.ops(left_cnn)
right_fc = fc_layer.ops(right_cnn)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_relu, right_relu)
return left_relu, pred
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_cnn, right_cnn])
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
concat_fc = relu_layer.ops(concat)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
concat_fc = fc_layer.ops(concat)
softmax_layer = layers.FCLayer(2, "softmax", "cos_sim")
pred = softmax_layer.ops(concat_fc)
return left_cnn, pred
......@@ -50,17 +50,17 @@ class GRU(object):
right_last = last_layer.ops(right_gru)
# matching layer
if self.task_mode == "pairwise":
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
left_relu = relu_layer.ops(left_last)
right_relu = relu_layer.ops(right_last)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
left_fc = fc_layer.ops(left_last)
right_fc = fc_layer.ops(right_last)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_relu, right_relu)
return left_relu, pred
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_last, right_last])
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
concat_fc = relu_layer.ops(concat)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
concat_fc = fc_layer.ops(concat)
softmax_layer = layers.FCLayer(2, "softmax", "cos_sim")
pred = softmax_layer.ops(concat_fc)
return left_last, pred
......@@ -49,17 +49,17 @@ class LSTM(object):
right_last = last_layer.ops(right_lstm)
# matching layer
if self.task_mode == "pairwise":
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
left_relu = relu_layer.ops(left_last)
right_relu = relu_layer.ops(right_last)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
left_fc = fc_layer.ops(left_last)
right_fc = fc_layer.ops(right_last)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_relu, right_relu)
return left_relu, pred
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_last, right_last])
relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu")
concat_fc = relu_layer.ops(concat)
fc_layer = layers.FCLayer(self.hidden_dim, None, "fc")
concat_fc = fc_layer.ops(concat)
softmax_layer = layers.FCLayer(2, "softmax", "cos_sim")
pred = softmax_layer.ops(concat_fc)
return left_last, pred
......@@ -6,10 +6,17 @@
基于百度海量搜索数据,我们训练了一个SimNet-BOW-Pairwise语义匹配模型,在一些真实的FAQ问答场景中,该模型效果比基于字面的相似度方法AUC提升5%以上,我们基于百度自建测试集(包含聊天、客服等数据集)和语义匹配数据集(LCQMC)进行评测,效果如下表所示。LCQMC数据集以Accuracy为评测指标,而pairwise模型的输出为相似度,因此我们采用0.958作为分类阈值,相比于基线模型中网络结构同等复杂的CBOW模型(准确率为0.737),我们模型的准确率为0.7532。
| 模型 | 百度知道 | ECOM |QQSIM | UNICOM | LCQMC |
|:-----------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
| | AUC | AUC | AUC|正逆序比|Accuracy|
|BOW_Pairwise|0.6767|0.7329|0.7650|1.5630|0.7532|
| 模型 | 百度知道 | ECOM |QQSIM | UNICOM |
|:-----------:|:-------------:|:-------------:|:-------------:|:-------------:|
| | AUC | AUC | AUC|正逆序比|
|BOW_Pairwise|0.6767|0.7329|0.7650|1.5630|
#### 测试集说明
| 数据集 | 来源 | 垂类 |
|:-----------:|:-------------:|:-------------:|
|百度知道 | 百度知道问题 | 日常 |
|ECOM|商业问句|金融|
|QQSIM|闲聊对话|日常|
|UNICOM|联通客服|客服|
## 快速开始
#### 版本依赖
本项目依赖于 PaddlePaddle Fluid 1.3.1,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。
......@@ -24,24 +31,14 @@ cd models/PaddleNLP/similarity_net
#### 数据准备
下载经过预处理的数据,运行命令后,data目录下会存在训练集数据示例、验证集数据示例、测试集数据示例,以及对应词索引字典(term2id.dict)。
```shell
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
tar xzf simnet_dataset-1.0.0.tar.gz
sh download_data.sh
```
#### 模型准备
我们开源了基于大规模数据训练好的```pairwise```模型(基于bow模型训练),我们提供两种下载方式,模型保存在```./model_files/simnet_bow_pairwise_pretrained_model/```下。
##### 方式一:基于PaddleHub命令行工具(PaddleHub[安装方式](https://github.com/PaddlePaddle/PaddleHub))
```shell
mkdir model_files
hub download simnet_bow_pairwise --output_path ./
tar xzf simnet_bow-pairwise-1.0.0.tar.gz -C ./model_files
```
##### 方式二:直接下载
我们开源了基于大规模数据训练好的```pairwise```模型(基于bow模型训练),用户可以通过运行命令下载预训练好的模型,该模型将保存在```./model_files/simnet_bow_pairwise_pretrained_model/```下。
```shell
mkdir model_files
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_bow-pairwise-1.0.0.tar.gz
tar xzf simnet_bow-pairwise-1.0.0.tar.gz -C ./model_files
sh download_pretrained_model.sh
```
#### 评估
我们公开了自建的测试集,包括百度知道、ECOM、QQSIM、UNICOM四个数据集,基于上面的预训练模型,用户可以进入evaluate目录下依次执行下列命令获取测试集评估结果。
```shell
......@@ -162,6 +159,7 @@ python run_classifier.py \
--task_mode ${TASK_MODE} \ #训练模式,pairwise或pointwise,与相应的配置文件匹配。
--compute_accuracy False \ #是否计算accuracy
--lamda 0.91 \ #pairwise模式计算accuracy时的阈值
--init_checkpoint "" #预加载模型路径
```
### 如何组建自己的模型
用户可以根据自己的需求,组建自定义的模型,具体方法如下所示:
......
......@@ -34,14 +34,12 @@ class SimNetConfig(object):
with open(config_path) as json_file:
config_dict = json.load(json_file)
except Exception:
raise IOError("Error in parsing simnet model config file '%s'" %
config_path)
raise IOError("Error in parsing simnet model config file '%s'" % config_path)
else:
if config_dict["task_mode"] != self.task_mode:
raise ValueError(
"the config '{}' does not match the task_mode '{}'".format(
self.config_path, self.task_mode))
"the config '{}' does not match the task_mode '{}'".format(self.config_path, self.task_mode))
return config_dict
def __getitem__(self, key):
......
#get data
# Download the SimNet dataset archive, unpack it into the current directory,
# and delete the archive afterwards.

# Stop at the first failing command so that a failed download is not followed
# by an attempt to extract or remove a missing/partial archive.
set -e

archive="simnet_dataset-1.0.0.tar.gz"

wget --no-check-certificate "https://baidu-nlp.bj.bcebos.com/${archive}"
tar xzf "${archive}"
rm "${archive}"
......@@ -4,13 +4,7 @@ model_files_path="./model_files"
#get pretrained_bow_pairwise_model
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_bow-pairwise-1.0.0.tar.gz
if [ ! -d $model_files_path ]; then
mkdir $model_files_path
mkdir $model_files_path
fi
tar xzf simnet_bow-pairwise-1.0.0.tar.gz -C $model_files_path
rm simnet_bow-pairwise-1.0.0.tar.gz
#get data
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
tar xzf simnet_dataset-1.0.0.tar.gz
rm simnet_dataset-1.0.0.tar.gz
rm simnet_bow-pairwise-1.0.0.tar.gz
\ No newline at end of file
......@@ -21,7 +21,7 @@ INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
train() {
python run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--use_cuda False \
--do_train True \
--do_valid True \
--do_test True \
......@@ -34,12 +34,13 @@ train() {
--output_dir ${CKPT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--epoch 10 \
--save_steps 1000 \
--validation_steps 100 \
--epoch 40 \
--save_steps 2000 \
--validation_steps 200 \
--compute_accuracy False \
--lamda 0.958 \
--task_mode ${TASK_MODE}
--task_mode ${TASK_MODE}\
--init_checkpoint ""
}
#run_evaluate
evaluate() {
......
......@@ -15,7 +15,7 @@
"""
SimNet utilities.
"""
import argparse
import time
import sys
import re
......@@ -26,20 +26,17 @@ import numpy as np
import logging
import logging.handlers
import paddle.fluid as fluid
import io
"""
******functions for file processing******
"""
def load_vocab(file_path):
"""
load the given vocabulary
"""
vocab = {}
if six.PY3:
f = open(file_path, "r", encoding="utf-8")
else:
f = open(file_path, "r")
f = io.open(file_path, "r", encoding="utf-8")
for line in f:
items = line.strip("\n").split("\t")
if items[0] not in vocab:
......@@ -61,8 +58,7 @@ def get_result_file(args):
"""
with codecs.open(args.test_data_dir, "r", "utf-8") as test_file:
with codecs.open("predictions.txt", "r", "utf-8") as predictions_file:
with codecs.open(args.test_result_path, "w",
"utf-8") as test_result_file:
with codecs.open(args.test_result_path, "w", "utf-8") as test_result_file:
test_datas = [line.strip("\n") for line in test_file]
predictions = [line.strip("\n") for line in predictions_file]
for test_data, prediction in zip(test_datas, predictions):
......@@ -170,6 +166,58 @@ class ArgumentGroup(object):
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
    """
    Command-line argument configuration for SimNet.

    Builds an ``argparse.ArgumentParser`` populated with the model, training,
    logging, data and run-type option groups, plus a "customize" group that
    callers can extend through ``add_arg`` before calling ``build_conf``.
    """

    def __init__(self):
        """
        Create the parser and register every built-in option.
        """
        parser = argparse.ArgumentParser()

        # Model configuration and checkpoint locations.
        model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
        model_g.add_arg("config_path", str, None, "Path to the json file for SimNet model config.")
        model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
        model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints")
        model_g.add_arg("task_mode", str, None, "task mode: pairwise or pointwise")

        # Training schedule.
        train_g = ArgumentGroup(parser, "training", "training options.")
        train_g.add_arg("epoch", int, 10, "Number of epochs for training.")
        train_g.add_arg("save_steps", int, 200, "The steps interval to save checkpoints.")
        train_g.add_arg("validation_steps", int, 100, "The steps interval to evaluate model performance.")

        # Logging and result output.
        log_g = ArgumentGroup(parser, "logging", "logging related")
        log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
        log_g.add_arg("verbose_result", bool, True, "Whether to output verbose result.")
        log_g.add_arg("test_result_path", str, "test_result", "Directory path to test result.")
        log_g.add_arg("infer_result_path", str, "infer_result", "Directory path to infer result.")

        # Input data and vocabulary.
        data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
        data_g.add_arg("train_data_dir", str, None, "Directory path to training data.")
        data_g.add_arg("valid_data_dir", str, None, "Directory path to valid data.")
        data_g.add_arg("test_data_dir", str, None, "Directory path to testing data.")
        data_g.add_arg("infer_data_dir", str, None, "Directory path to infer data.")
        data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
        data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.")

        # Which phases to run and how.
        run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
        run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
        run_type_g.add_arg("task_name", str, None, "The name of task to perform semantic matching.")
        run_type_g.add_arg("do_train", bool, False, "Whether to perform training.")
        run_type_g.add_arg("do_valid", bool, False, "Whether to perform dev.")
        run_type_g.add_arg("do_test", bool, False, "Whether to perform testing.")
        run_type_g.add_arg("do_infer", bool, False, "Whether to perform inference.")
        run_type_g.add_arg("compute_accuracy", bool, False, "Whether to compute accuracy.")
        run_type_g.add_arg("lamda", float, 0.91, "When task_mode is pairwise, lamda is the threshold for calculating the accuracy.")

        # Empty group kept on the instance so that add_arg() can extend it.
        custom_g = ArgumentGroup(parser, "customize", "customized options.")
        self.custom_g = custom_g

        # Plain flag registered directly on the parser (not through a group).
        parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')

        self.parser = parser

    def add_arg(self, name, dtype, default, descrip):
        """
        Register an additional option in the "customize" group.

        name: option name, forwarded to ArgumentGroup.add_arg
        dtype: type of the option value
        default: default value of the option
        descrip: help text of the option
        """
        self.custom_g.add_arg(name, dtype, default, descrip)

    def build_conf(self):
        """
        Parse the command line and return the result of ``parse_args()``.
        """
        return self.parser.parse_args()
def print_arguments(args):
"""
......@@ -302,7 +350,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
"""
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
if not fluid.io.is_persistable(var):
return False
......@@ -314,3 +362,4 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
main_program=main_program,
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册