Unverified · Commit 98c94981 · authored by W wuzhihua, committed by GitHub

Merge pull request #223 from yinhaofeng/change-dataset

change simnet and dssm data
@@ -18,6 +18,7 @@ import collections
 import os
 import csv
 import re
+import io
 import sys
 if six.PY2:
     reload(sys)
@@ -45,11 +46,11 @@ def build_dict(column_num=2, min_word_freq=0, train_dir="", test_dir=""):
     word_freq = collections.defaultdict(int)
     files = os.listdir(train_dir)
     for fi in files:
-        with open(os.path.join(train_dir, fi), "r", encoding='utf-8') as f:
+        with io.open(os.path.join(train_dir, fi), "r", encoding='utf-8') as f:
             word_freq = word_count(column_num, f, word_freq)
     files = os.listdir(test_dir)
     for fi in files:
-        with open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
+        with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
             word_freq = word_count(column_num, f, word_freq)
     word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
@@ -65,51 +66,51 @@ def write_paddle(text_idx, tag_idx, train_dir, test_dir, output_train_dir,
     if not os.path.exists(output_train_dir):
         os.mkdir(output_train_dir)
     for fi in files:
-        with open(os.path.join(train_dir, fi), "r", encoding='utf-8') as f:
-            with open(
+        with io.open(os.path.join(train_dir, fi), "r", encoding='utf-8') as f:
+            with io.open(
                     os.path.join(output_train_dir, fi), "w",
                     encoding='utf-8') as wf:
                 data_file = csv.reader(f)
                 for row in data_file:
                     tag_raw = re.split(r'\W+', row[0].strip())
                     pos_index = tag_idx.get(tag_raw[0])
-                    wf.write(str(pos_index) + ",")
+                    wf.write(u"{},".format(str(pos_index)))
                     text_raw = re.split(r'\W+', row[2].strip())
                     l = [text_idx.get(w) for w in text_raw]
                     for w in l:
-                        wf.write(str(w) + " ")
-                    wf.write("\n")
+                        wf.write(u"{} ".format(str(w)))
+                    wf.write(u"\n")
     files = os.listdir(test_dir)
     if not os.path.exists(output_test_dir):
         os.mkdir(output_test_dir)
     for fi in files:
-        with open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
-            with open(
+        with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
+            with io.open(
                     os.path.join(output_test_dir, fi), "w",
                     encoding='utf-8') as wf:
                 data_file = csv.reader(f)
                 for row in data_file:
                     tag_raw = re.split(r'\W+', row[0].strip())
                     pos_index = tag_idx.get(tag_raw[0])
-                    wf.write(str(pos_index) + ",")
+                    wf.write(u"{},".format(str(pos_index)))
                     text_raw = re.split(r'\W+', row[2].strip())
                     l = [text_idx.get(w) for w in text_raw]
                     for w in l:
-                        wf.write(str(w) + " ")
-                    wf.write("\n")
+                        wf.write(u"{} ".format(str(w)))
+                    wf.write(u"\n")
 def text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
                 output_vocab_text, output_vocab_tag):
     print("start constuct word dict")
     vocab_text = build_dict(2, 0, train_dir, test_dir)
-    with open(output_vocab_text, "w", encoding='utf-8') as wf:
-        wf.write(str(len(vocab_text)) + "\n")
+    with io.open(output_vocab_text, "w", encoding='utf-8') as wf:
+        wf.write(u"{}\n".format(str(len(vocab_text))))
     vocab_tag = build_dict(0, 0, train_dir, test_dir)
-    with open(output_vocab_tag, "w", encoding='utf-8') as wf:
-        wf.write(str(len(vocab_tag)) + "\n")
+    with io.open(output_vocab_tag, "w", encoding='utf-8') as wf:
+        wf.write(u"{}\n".format(str(len(vocab_tag))))
     print("construct word dict done\n")
     write_paddle(vocab_text, vocab_tag, train_dir, test_dir, output_train_dir,
......
@@ -29,11 +29,12 @@ dataset:
 hyper_parameters:
   optimizer:
-    class: sgd
+    class: adam
     learning_rate: 0.001
-    strategy: async
+    strategy: sync
-  trigram_d: 1439
+  trigram_d: 2900
   neg_num: 1
+  slice_end: 8
   fc_sizes: [300, 300, 128]
   fc_acts: ['tanh', 'tanh', 'tanh']
@@ -44,7 +45,7 @@ runner:
 - name: train_runner
   class: train
   # num of epochs
-  epochs: 3
+  epochs: 1
   # device to run training or infer
   device: cpu
   save_checkpoint_interval: 1 # save model interval of epochs
@@ -54,14 +55,14 @@ runner:
   save_inference_feed_varnames: ["query", "doc_pos"] # feed vars of save inference
   save_inference_fetch_varnames: ["cos_sim_0.tmp_0"] # fetch vars of save inference
   init_model_path: "" # load model path
-  print_interval: 2
+  print_interval: 10
   phases: phase1
 - name: infer_runner
   class: infer
   # device to run training or infer
   device: cpu
   print_interval: 1
-  init_model_path: "increment/2" # load model path
+  init_model_path: "increment/0" # load model path
   phases: phase2
 # runner will run all the phase in each epoch
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
tar xzf dssm%2Fbq.tar.gz
rm -f dssm%2Fbq.tar.gz
mv bq/train.txt ./raw_data.txt
python3 preprocess.py
mkdir big_train
mv train.txt ./big_train
mkdir big_test
mv test.txt ./big_test
+#encoding=utf-8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#encoding=utf-8
 import os
 import sys
+import jieba
 import numpy as np
 import random
-f = open("./zhidao", "r")
+f = open("./raw_data.txt", "r")
 lines = f.readlines()
 f.close()
@@ -26,14 +27,15 @@ f.close()
 word_dict = {}
 for line in lines:
     line = line.strip().split("\t")
-    text = line[0].split(" ") + line[1].split(" ")
+    text = line[0].strip("") + " " + line[1].strip("")
+    text = jieba.cut(text)
     for word in text:
         if word in word_dict:
             continue
         else:
             word_dict[word] = len(word_dict) + 1
-f = open("./zhidao", "r")
+f = open("./raw_data.txt", "r")
 lines = f.readlines()
 f.close()
@@ -57,12 +59,13 @@ for line in lines:
         else:
             pos_dict[line[0]] = [line[1]]
+print("build dict done")
 # split the data into train and test sets
 query_list = list(pos_dict.keys())
-#print(len(query))
+#print(len(query_list))
-random.shuffle(query_list)
+#random.shuffle(query_list)
-train_query = query_list[:90]
+train_query = query_list[:11600]
-test_query = query_list[90:]
+test_query = query_list[11600:]
 # build the training set
 train_set = []
@@ -73,6 +76,7 @@ for query in train_query:
         for neg in neg_dict[query]:
             train_set.append([query, pos, neg])
 random.shuffle(train_set)
+print("get train_set done")
 # build the test set
 test_set = []
@@ -84,13 +88,14 @@ for query in test_query:
         for neg in neg_dict[query]:
             test_set.append([query, neg, 0])
 random.shuffle(test_set)
+print("get test_set done")
 # convert query, pos and neg in the training set to bag-of-words vectors
 f = open("train.txt", "w")
 for line in train_set:
-    query = line[0].strip().split(" ")
-    pos = line[1].strip().split(" ")
-    neg = line[2].strip().split(" ")
+    query = jieba.cut(line[0].strip())
+    pos = jieba.cut(line[1].strip())
+    neg = jieba.cut(line[2].strip())
     query_token = [0] * (len(word_dict) + 1)
     for word in query:
         query_token[word_dict[word]] = 1
@@ -109,8 +114,8 @@ f.close()
 f = open("test.txt", "w")
 fa = open("label.txt", "w")
 for line in test_set:
-    query = line[0].strip().split(" ")
-    pos = line[1].strip().split(" ")
+    query = jieba.cut(line[0].strip())
+    pos = jieba.cut(line[1].strip())
     label = line[2]
     query_token = [0] * (len(word_dict) + 1)
     for word in query:
......
@@ -29,6 +29,7 @@ class Model(ModelBase):
         self.hidden_acts = envs.get_global_env("hyper_parameters.fc_acts")
         self.learning_rate = envs.get_global_env(
             "hyper_parameters.learning_rate")
+        self.slice_end = envs.get_global_env("hyper_parameters.slice_end")
     def input_data(self, is_infer=False, **kwargs):
         query = fluid.data(
@@ -94,7 +95,7 @@ class Model(ModelBase):
         prob = fluid.layers.softmax(concat_Rs, axis=1)
         hit_prob = fluid.layers.slice(
-            prob, axes=[0, 1], starts=[0, 0], ends=[8, 1])
+            prob, axes=[0, 1], starts=[0, 0], ends=[self.slice_end, 1])
         loss = -fluid.layers.reduce_sum(fluid.layers.log(hit_prob))
         avg_cost = fluid.layers.mean(x=loss)
         self._cost = avg_cost
......
@@ -4,11 +4,12 @@
 ```
 ├── data # sample data
 ├── train
 ├── train.txt # sample training data
 ├── test
 ├── test.txt # sample test data
 ├── preprocess.py # data preprocessing script
+├── data_process # one-click data processing script
 ├── __init__.py
 ├── README.md # documentation
 ├── model.py # model file
@@ -46,13 +47,19 @@ Query 和 Doc 的语义相似性可以用这两个向量的 cosine 距离表示
 <p>
 ## Data Preparation
-We released a self-built benchmark that contains four datasets: Baidu Zhidao, ECOM, QQSIM and UNICOM. Here the Baidu Zhidao dataset is used for training. Run the following commands to download it.
+BQ is a Chinese question-matching dataset for intelligent customer service. It is an automatic question-answering corpus with 120,000 sentence pairs, each annotated with a similarity label. The data contains typos and non-standard grammar, which makes it closer to real industrial scenarios. Run the following commands to download the dataset.
 ```
-wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
-tar xzf simnet_dataset-1.0.0.tar.gz
-rm simnet_dataset-1.0.0.tar.gz
+wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
+tar xzf dssm%2Fbq.tar.gz
+rm -f dssm%2Fbq.tar.gz
 ```
+Sample records from the dataset:
+```
+请问一天是否都是限定只能转入或转出都是五万。	微众多少可以赎回短期理财	0
+微粒咨询电话号码多少	你们的人工客服电话是多少	1
+已经在银行换了新预留号码。	我现在换了电话号码,这个需要更换吗	1
+Fields are tab-separated: columns 1 and 2 are the two texts, and column 3 is the label (0 = the texts are not similar, 1 = they are similar).
 ```
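As a minimal sketch only (not part of the repo's scripts), the tab-separated format described above can be read like this in Python; the file name raw_data.txt matches the rename done later in this README:

```python
# Sketch: iterate over the BQ pairs file; each line is "text_a<TAB>text_b<TAB>label",
# where label 1 means the two texts are similar and 0 means they are not.
with open("raw_data.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 3:
            continue  # skip malformed lines
        text_a, text_b, label = parts[0], parts[1], int(parts[2])
        # ... hand (text_a, text_b, label) to the preprocessing step
```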
 ## Runtime Environment
 PaddlePaddle>=1.7.2
@@ -120,21 +127,24 @@ PaddleRec Finish
 2. Download and extract the dataset under the data directory:
 ```
 cd data
-wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
-tar xzf simnet_dataset-1.0.0.tar.gz
-rm simnet_dataset-1.0.0.tar.gz
+wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
+tar xzf dssm%2Fbq.tar.gz
+rm -f dssm%2Fbq.tar.gz
 ```
-3. We provide a script that quickly converts the Chinese text in the dataset into a trainable format. After extracting the dataset you will see a file named zhidao in the directory. Run the provided preprocess.py under Python 3 to generate test.txt, train.txt and label.txt, which can be used for training directly. Put them into the train and test directories for training. Commands:
+3. We provide a script that quickly converts the Chinese text in the dataset into a trainable format. After extracting the dataset you will see a directory named bq. Move its train.txt into the data directory, then run the provided preprocess.py under Python 3 to generate test.txt, train.txt and label.txt, which can be used directly for training. Put them into the train and test directories for training. Preprocessing takes a while, please be patient. Commands:
 ```
-mv data/zhidao ./
-rm -rf data
+mv bq/train.txt ./raw_data.txt
 python3 preprocess.py
-rm -f ./train/train.txt
-mv train.txt ./train
-rm -f ./test/test.txt
-mv test.txt test
+mkdir big_train
+mv train.txt ./big_train
+mkdir big_test
+mv test.txt ./big_test
 cd ..
 ```
+You can also use the one-click data processing script data_process.sh that we provide:
+```
+sh data_process.sh
+```
 Format after preprocessing:
 The training set consists of three sparse bag-of-words (BOW) vectors: query, pos and neg.
 The test set consists of two sparse BOW vectors: query and pos.
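As a rough illustration of the sparse BOW representation described above (a sketch only; the actual conversion lives in preprocess.py, which tokenizes with jieba and builds word_dict from the raw corpus):

```python
import jieba

def to_bow(sentence, word_dict):
    # Bag-of-words vector of length len(word_dict) + 1: the position word_dict[w]
    # is set to 1 for every word w produced by jieba, mirroring preprocess.py.
    vec = [0] * (len(word_dict) + 1)
    for word in jieba.cut(sentence.strip()):
        if word in word_dict:
            vec[word_dict[word]] = 1
    return vec

# One training line is then the query, pos and neg vectors of a single triple.
```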
@@ -144,8 +154,10 @@ label.txt中对应的测试集中的标签
 Change workspace to your current absolute path (you can get it with the pwd command).
 Change batch_size in dataset_train from 8 to 128.
-In model.py, change hit_prob = fluid.layers.slice(prob, axes=[0, 1], starts=[0, 0], ends=[8, 1])
-to hit_prob = fluid.layers.slice(prob, axes=[0, 1], starts=[0, 0], ends=[128, 1]). Whenever you change the batch size, the first value of ends must change with it.
+Change slice_end in hyper_parameters from 8 to 128. Whenever you change the batch size, this parameter must change with it (see the sketch below).
+Change data_path in dataset_train to {workspace}/data/big_train.
+Change data_path in dataset_infer to {workspace}/data/big_test.
+Change trigram_d in hyper_parameters to 5913.
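Why slice_end has to follow batch_size: in model.py the softmax output prob has shape [batch_size, neg_num + 1], and hit_prob keeps only the first column of every row, so the slice end along axis 0 must equal the batch size. A NumPy sketch of the equivalent slicing, with illustrative values:

```python
import numpy as np

batch_size, neg_num = 128, 1
prob = np.random.rand(batch_size, neg_num + 1)   # stands in for the softmax output
slice_end = batch_size                           # keep in sync with config.yaml
hit_prob = prob[0:slice_end, 0:1]                # same effect as fluid.layers.slice(..., ends=[slice_end, 1])
loss = -np.log(hit_prob).sum()                   # the loss reduces over these probabilities
```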
 5. Run the script to start training. It executes python -m paddlerec.run -m ./config.yaml, writes the output to the result file, then runs transform.py to consolidate the data, and finally computes the positive/negative order metrics:
 ```
@@ -155,26 +167,14 @@ sh run.sh
 Sample output:
 ```
 ................run.................
-!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.
-CPU_NUM indicates that how many CPUPlace are used in the current task.
-And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.
-export CPU_NUM=32 # for example, set CPU_NUM as number of physical CPU core which is 32.
-!!! The default number of CPU_NUM=1.
-I0821 07:16:04.512531 32200 parallel_executor.cc:440] The Program will be executed on CPU using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
-I0821 07:16:04.515708 32200 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
-I0821 07:16:04.518872 32200 parallel_executor.cc:307] Inplace strategy is enabled, when build_strategy.enable_inplace = True
-I0821 07:16:04.520995 32200 parallel_executor.cc:375] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
-75
-pnr: 2.25581395349
-query_num: 11
-pair_num: 184 184
-equal_num: 44
-正序率: 0.692857142857
-97 43
+8989
+pnr:2.75621659307
+query_num:1369
+pair_num:16240 , 16240
+equal_num:77
+正序率: 0.733774670544
+pos_num: 11860 , neg_num: 4303
 ```
-6. Reminder: because training and testing used a small dataset, the metrics can fluctuate considerably. If the metrics do not look right, rerun step 5 a few times to obtain reasonable numbers.
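For reference, the pnr and 正序率 (positive-order ratio) values above are computed by comparing, within each query, every pair of predictions whose labels differ. A simplified sketch of that bookkeeping, assuming rows shaped like pair.txt, i.e. (query, score, label); the evaluation script in the repo streams the file instead of holding it in memory:

```python
from collections import defaultdict

def order_metrics(rows):
    # rows: iterable of (query, score, label); count concordant / discordant pairs
    # within each query, as the pairwise evaluation script in this repo does.
    pos = neg = equal = 0
    by_query = defaultdict(list)
    for query, score, label in rows:
        by_query[query].append((score, label))
    for items in by_query.values():
        for i in range(len(items)):
            for j in range(i + 1, len(items)):
                (s1, l1), (s2, l2) = items[i], items[j]
                if l1 == l2:
                    continue
                d = (s1 - s2) * (l1 - l2)
                if d > 0:
                    pos += 1      # scores ordered like the labels
                elif d < 0:
                    neg += 1
                else:
                    equal += 1
    pnr = 1.0 * pos / neg if neg else float("inf")
    ratio = 1.0 * pos / (pos + neg) if (pos + neg) else 0.0
    return pnr, ratio             # pnr and 正序率 (positive-order ratio)
```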
 ## Advanced Usage
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 #!/bin/bash
 echo "................run................."
-python -m paddlerec.run -m ./config.yaml >result1.txt
+python -m paddlerec.run -m ./config.yaml &> result1.txt
 grep -i "query_doc_sim" ./result1.txt >./result2.txt
 sed '$d' result2.txt >result.txt
 rm -f result1.txt
......
@@ -32,13 +32,13 @@ filename = './result.txt'
 sim = []
 for line in open(filename):
     line = line.strip().split(",")
-    line[1] = line[1].split(":")
-    line = line[1][1].strip(" ")
+    line[3] = line[3].split(":")
+    line = line[3][1].strip(" ")
     line = line.strip("[")
     line = line.strip("]")
     sim.append(float(line))
-filename = './data/test/test.txt'
+filename = './data/big_test/test.txt'
 f = open(filename, "r")
 f.readline()
 query = []
......
@@ -106,7 +106,7 @@ def make_train():
             pair_list.append((d1, high_d2, low_d2))
     print('Pair Instance Count:', len(pair_list))
-    f = open("./data/train/train.txt", "w")
+    f = open("./data/big_train/train.txt", "w")
     for batch in range(800):
         X1 = np.zeros((batch_size * 2, data1_maxlen), dtype=np.int32)
         X2 = np.zeros((batch_size * 2, data2_maxlen), dtype=np.int32)
@@ -131,7 +131,7 @@ def make_train():
 def make_test():
     rel = read_relation(filename=os.path.join(Letor07Path,
                                               'relation.test.fold1.txt'))
-    f = open("./data/test/test.txt", "w")
+    f = open("./data/big_test/test.txt", "w")
     for label, d1, d2 in rel:
         X1 = np.zeros(data1_maxlen, dtype=np.int32)
         X2 = np.zeros(data2_maxlen, dtype=np.int32)
......
@@ -3,7 +3,9 @@
 echo "...........load data................."
 wget --no-check-certificate 'https://paddlerec.bj.bcebos.com/match_pyramid/match_pyramid_data.tar.gz'
 mv ./match_pyramid_data.tar.gz ./data
-rm -rf ./data/relation.test.fold1.txt ./data/realtion.train.fold1.txt
+rm -rf ./data/relation.test.fold1.txt
 tar -xvf ./data/match_pyramid_data.tar.gz
+mkdir ./data/big_train
+mkdir ./data/big_test
 echo "...........data process..............."
 python ./data/process.py
@@ -49,8 +49,8 @@ filename = './result.txt'
 pred = []
 for line in open(filename):
     line = line.strip().split(",")
-    line[1] = line[1].split(":")
-    line = line[1][1].strip(" ")
+    line[3] = line[3].split(":")
+    line = line[3][1].strip(" ")
     line = line.strip("[")
     line = line.strip("]")
     pred.append(float(line))
......
@@ -56,10 +56,10 @@
 4. Embedding file: we store the pre-trained word vectors in an embedding file, e.g. embed_wiki-pdc_d50_norm.
 ## Runtime Environment
 PaddlePaddle>=1.7.2
 python 2.7/3.5/3.6/3.7
 PaddleRec >=0.1
 os : windows/linux/macos
 ## Quick Start
@@ -72,7 +72,7 @@ python -m paddlerec.run -m models/match/match-pyramid/config.yaml
 ## Reproducing the Paper
 1. Make sure your current directory is PaddleRec/models/match/match-pyramid.
 2. We provide a preprocessing script that downloads the original dataset and generates the training and test data in one step; you can simply run: bash data_process.sh
-Running this script downloads the Letor07 dataset from a mirror in China, deletes the existing relation.test.fold1.txt and relation.train.fold1.txt from the data folder, and extracts the full dataset into the data folder. It then runs process.py, which places the full training data in `./data/train` and the full test data in `./data/test`, and generates the embedding.npy file used to initialize the embedding layer.
+Running this script downloads the Letor07 dataset from a mirror in China and extracts the full dataset into the data folder. It then runs process.py, which places the full training data in `./data/big_train` and the full test data in `./data/big_test`, and generates the embedding.npy file used to initialize the embedding layer.
 The expected output of the script is:
 ```
 bash data_process.sh
@@ -123,6 +123,8 @@ data/embed_wiki-pdc_d50_norm
 3. Open config.yaml and change its parameters:
 Change workspace to your current absolute path (you can get it with the pwd command).
+Change the data_path parameter under dataset_train to {workspace}/data/big_train.
+Change the data_path parameter under dataset_infer to {workspace}/data/big_test.
 4. Then simply run: bash run.sh to reproduce the results from the paper.
 Running this script executes python -m paddlerec.run -m ./config.yaml to train and test the model, saves the test results to result.txt, and finally runs eval.py to compute the MAP metric on the data.
@@ -131,7 +133,7 @@ data/embed_wiki-pdc_d50_norm
 ..............test.................
 13651
 336
-('map=', 0.420878322843591)
+('map=', 0.3993127885738651)
 ```
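The map value printed above comes from eval.py, which groups the predictions in result.txt by query and averages per-query precision. A generic sketch of mean average precision with binary relevance labels (the exact grouping in eval.py may differ):

```python
def mean_average_precision(results):
    # results: {query_id: [(score, label), ...]} with label 1 = relevant, 0 = irrelevant.
    ap_sum = 0.0
    for pairs in results.values():
        ranked = sorted(pairs, key=lambda x: x[0], reverse=True)
        hits, precisions = 0, []
        for rank, (_, label) in enumerate(ranked, start=1):
            if label == 1:
                hits += 1
                precisions.append(hits / float(rank))
        if precisions:
            ap_sum += sum(precisions) / len(precisions)
    return ap_sum / len(results)
```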
 ## Advanced Usage
......
 #!/bin/bash
 echo "................run................."
-python -m paddlerec.run -m ./config.yaml >result1.txt
-grep -i "prediction" ./result1.txt >./result.txt
+python -m paddlerec.run -m ./config.yaml &>result1.txt
+grep -i "prediction" ./result1.txt >./result2.txt
+sed '$d' result2.txt >result.txt
+rm -f result2.txt
 rm -f result1.txt
 python eval.py
@@ -26,19 +26,19 @@ dataset:
     batch_size: 1
     type: DataLoader # or QueueDataset
     data_path: "{workspace}/data/test"
-    sparse_slots: "1 2"
+    sparse_slots: "0 1"
 # hyper parameters of user-defined network
 hyper_parameters:
   optimizer:
     class: Adam
-    learning_rate: 0.0001
-    strategy: async
+    learning_rate: 0.001
+    strategy: sync
   query_encoder: "gru"
   title_encoder: "gru"
   query_encode_dim: 128
   title_encode_dim: 128
-  sparse_feature_dim: 1439
+  sparse_feature_dim: 6327
   embedding_dim: 128
   hidden_size: 128
   margin: 0.1
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
tar xzf dssm%2Fbq.tar.gz
rm -f dssm%2Fbq.tar.gz
mv bq/train.txt ./raw_data.txt
python3 preprocess.py
mkdir big_train
mv train.txt ./big_train
mkdir big_test
mv test.txt ./big_test
+#encoding=utf-8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#encoding=utf-8
 import os
 import sys
+import jieba
 import numpy as np
 import random
-f = open("./zhidao", "r")
+f = open("./raw_data.txt", "r")
 lines = f.readlines()
 f.close()
@@ -26,14 +27,15 @@ f.close()
 word_dict = {}
 for line in lines:
     line = line.strip().split("\t")
-    text = line[0].split(" ") + line[1].split(" ")
+    text = line[0].strip("") + line[1].strip("")
+    text = jieba.cut(text)
     for word in text:
         if word in word_dict:
             continue
         else:
             word_dict[word] = len(word_dict) + 1
-f = open("./zhidao", "r")
+f = open("./raw_data.txt", "r")
 lines = f.readlines()
 f.close()
@@ -59,10 +61,10 @@ for line in lines:
 # split the data into train and test sets
 query_list = list(pos_dict.keys())
-#print(len(query_list))
+print(len(query_list))
 random.shuffle(query_list)
-train_query = query_list[:90]
+train_query = query_list[:11600]
-test_query = query_list[90:]
+test_query = query_list[11600:]
 # build the training set
 train_set = []
@@ -88,9 +90,9 @@ random.shuffle(test_set)
 # convert the format of query, pos and neg in the training set
 f = open("train.txt", "w")
 for line in train_set:
-    query = line[0].strip().split(" ")
-    pos = line[1].strip().split(" ")
-    neg = line[2].strip().split(" ")
+    query = jieba.cut(line[0].strip())
+    pos = jieba.cut(line[1].strip())
+    neg = jieba.cut(line[2].strip())
     query_list = []
     for word in query:
         query_list.append(word_dict[word])
@@ -110,8 +112,8 @@ f = open("test.txt", "w")
 fa = open("label.txt", "w")
 fb = open("testquery.txt", "w")
 for line in test_set:
-    query = line[0].strip().split(" ")
-    pos = line[1].strip().split(" ")
+    query = jieba.cut(line[0].strip())
+    pos = jieba.cut(line[1].strip())
     label = line[2]
     query_list = []
     for word in query:
......
@@ -4,11 +4,12 @@
 ```
 ├── data # sample data
 ├── train
 ├── train.txt # sample training data
 ├── test
 ├── test.txt # sample test data
 ├── preprocess.py # data preprocessing script
+├── data_process.sh # one-click data processing script
 ├── __init__.py
 ├── README.md # documentation
 ├── model.py # model file
@@ -42,14 +43,20 @@
 <p>
 ## Data Preparation
-We released a self-built benchmark that contains four datasets: Baidu Zhidao, ECOM, QQSIM and UNICOM. Here the Baidu Zhidao dataset is used for training. Run the following commands to download it.
+BQ is a Chinese question-matching dataset for intelligent customer service. It is an automatic question-answering corpus with 120,000 sentence pairs, each annotated with a similarity label. The data contains typos and non-standard grammar, which makes it closer to real industrial scenarios. Run the following commands to download the dataset.
 ```
-wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
-tar xzf simnet_dataset-1.0.0.tar.gz
-rm simnet_dataset-1.0.0.tar.gz
+wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
+tar xzf dssm%2Fbq.tar.gz
+rm -f dssm%2Fbq.tar.gz
 ```
-The data format is a slot that identifies the sentence, followed by the tokens of the words in that sentence; each word is written as {slot:token}:
+Sample records from the dataset:
+```
+请问一天是否都是限定只能转入或转出都是五万。	微众多少可以赎回短期理财	0
+微粒咨询电话号码多少	你们的人工客服电话是多少	1
+已经在银行换了新预留号码。	我现在换了电话号码,这个需要更换吗	1
+Fields are tab-separated: columns 1 and 2 are the two texts, and column 3 is the label (0 = the texts are not similar, 1 = they are similar).
+```
+The final output format is a slot that identifies the sentence, followed by the token ids of the words in that sentence; each word is written as {slot:token}:
 ```
 0:358 0:206 0:205 0:250 0:9 0:3 0:207 0:10 0:330 0:164 1:1144 1:217 1:206 1:9 1:3 1:207 1:10 1:398 1:2 2:217 2:206 2:9 2:3 2:207 2:10 2:398 2:2
 0:358 0:206 0:205 0:250 0:9 0:3 0:207 0:10 0:330 0:164 1:951 1:952 1:206 1:9 1:3 1:207 1:10 1:398 2:217 2:206 2:9 2:3 2:207 2:10 2:398 2:2
 ```
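A hedged sketch of how one training line in this {slot:token} format could be produced from a (query, pos, neg) triple; the real conversion is done by preprocess.py, and the helper below, including the slot assignment, is illustrative only:

```python
import jieba

def to_slot_line(query, pos, neg, word_dict):
    # Assumption: slot 0 = query, slot 1 = positive title, slot 2 = negative title;
    # every token becomes "slot:word_id" using the ids stored in word_dict.
    fields = []
    for slot, sentence in enumerate([query, pos, neg]):
        for word in jieba.cut(sentence.strip()):
            if word in word_dict:
                fields.append("{}:{}".format(slot, word_dict[word]))
    return " ".join(fields)
```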
@@ -75,24 +82,29 @@ python -m paddlerec.run -m models/match/multiview-simnet/config.yaml
 2. Download and extract the dataset under the data directory:
 ```
 cd data
-wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
-tar xzf simnet_dataset-1.0.0.tar.gz
-rm -f simnet_dataset-1.0.0.tar.gz
-mv data/zhidao ./
-rm -rf data
+wget https://paddlerec.bj.bcebos.com/dssm%2Fbq.tar.gz
+tar xzf dssm%2Fbq.tar.gz
+rm -f dssm%2Fbq.tar.gz
+mv bq/train.txt ./raw_data.txt
 ```
-3. We provide a script that quickly converts the Chinese text in the dataset into a trainable format. After extracting the dataset you will see a file named zhidao in the directory. Run the provided preprocess.py under Python 3 to generate test.txt, train.txt, label.txt and testquery.txt, which can be used directly for training. Put them into the train and test directories for training. Commands:
+3. We provide a script that quickly converts the Chinese text in the dataset into a trainable format. After extracting the dataset you will see a directory named bq. Move its train.txt into the data directory, then run the provided preprocess.py under Python 3 to generate test.txt, train.txt, label.txt and testquery.txt, which can be used directly for training. Put them into the train and test directories for training. Preprocessing takes a while, please be patient. Commands:
 ```
 python3 preprocess.py
-rm -f ./train/train.txt
-mv train.txt ./train
-rm -f ./test/test.txt
-mv test.txt ./test
+mkdir big_train
+mv train.txt ./big_train
+mkdir big_test
+mv test.txt ./big_test
 cd ..
 ```
-4. Go back to the tagspace directory, open config.yaml and change its parameters:
+You can also use the one-click data processing script data_process.sh that we provide:
+```
+sh data_process.sh
+```
+4. Go back to the multiview-simnet directory, open config.yaml and change its parameters:
 Change workspace to your current absolute path (you can get it with the pwd command).
+Change data_path in dataset_train to {workspace}/data/big_train.
+Change data_path in dataset_infer to {workspace}/data/big_test.
 5. Run the script to start training. It executes python -m paddlerec.run -m ./config.yaml, writes the output to the result file, then runs the formatting program transform, and finally computes the positive/negative order ratio:
 ```
@@ -102,26 +114,14 @@ sh run.sh
 The output looks roughly like this:
 ```
 ................run.................
-!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.
-CPU_NUM indicates that how many CPUPlace are used in the current task.
-And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.
-export CPU_NUM=32 # for example, set CPU_NUM as number of physical CPU core which is 32.
-!!! The default number of CPU_NUM=1.
-I0821 14:24:57.255358 7888 parallel_executor.cc:440] The Program will be executed on CPU using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
-I0821 14:24:57.259166 7888 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
-I0821 14:24:57.262634 7888 parallel_executor.cc:307] Inplace strategy is enabled, when build_strategy.enable_inplace = True
-I0821 14:24:57.264791 7888 parallel_executor.cc:375] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
-103
-pnr: 1.17674418605
-query_num: 11
-pair_num: 468 468
+8902
+pnr: 13.6785350966
+query_num: 1371
+pair_num: 14429 14429
 equal_num: 0
-正序率: 0.540598290598
-253 215
+正序率: 0.931873310694
+13446 983
 ```
-6. Reminder: because training and testing used a small dataset, the metrics can fluctuate considerably. If the metrics do not look right, rerun step 5 a few times to obtain reasonable numbers.
 ## Advanced Usage
 ## FAQ
@@ -14,7 +14,7 @@
 #!/bin/bash
 echo "................run................."
-python -m paddlerec.run -m ./config.yaml >result1.txt
+python -m paddlerec.run -m ./config.yaml &>result1.txt
 grep -i "query_pt_sim" ./result1.txt >./result2.txt
 sed '$d' result2.txt >result.txt
 rm -f result1.txt
......
@@ -31,8 +31,9 @@ filename = './result.txt'
 sim = []
 for line in open(filename):
     line = line.strip().split(",")
-    line[1] = line[1].split(":")
-    line = line[1][1].strip(" ")
+    print(line)
+    line[3] = line[3].split(":")
+    line = line[3][1].strip(" ")
     line = line.strip("[")
     line = line.strip("]")
     sim.append(float(line))
@@ -49,5 +50,6 @@ f.close()
 filename = 'pair.txt'
 f = open(filename, "w")
 for i in range(len(sim)):
+    print(i)
     f.write(str(query[i]) + "\t" + str(sim[i]) + "\t" + str(label[i]) + "\n")
 f.close()
+#encoding=utf-8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#!/usr/bin/python
-#-*- coding:utf-8 -*-
 """
 docstring
 """
@@ -21,10 +20,10 @@ import os
 import sys
 if len(sys.argv) < 2:
-    print "usage:python %s input" % (sys.argv[0])
+    print("usage:python {} input".format(sys.argv[0]))
     sys.exit(-1)
-fin = file(sys.argv[1])
+fin = open(sys.argv[1])
 pos_num = 0
 neg_num = 0
@@ -42,15 +41,15 @@ for line in fin:
     cols = line.strip().split("\t")
     cnt += 1
     if cnt % 500000 == 0:
-        print "cnt:", cnt, 1.0 * pos_num / neg_num
+        print("cnt:{}".format(1.0 * pos_num / neg_num))
     if len(cols) != 3:
         continue
     cur_query = cols[0]
     if cur_query != last_query:
         query_num += 1
-        for i in xrange(0, len(score_list)):
-            for j in xrange(i + 1, len(score_list)):
+        for i in range(0, len(score_list)):
+            for j in range(i + 1, len(score_list)):
                 if label_list[i] == label_list[j]:
                     continue
                 pair_num += 1
@@ -74,8 +73,8 @@ for line in fin:
 fin.close()
-for i in xrange(0, len(score_list)):
-    for j in xrange(i + 1, len(score_list)):
+for i in range(0, len(score_list)):
+    for j in range(i + 1, len(score_list)):
         if label_list[i] == label_list[j]:
             continue
         pair_num += 1
@@ -89,9 +88,9 @@ for i in xrange(0, len(score_list)):
             equal_num += 1
 if neg_num > 0:
-    print "pnr:", 1.0 * pos_num / neg_num
-    print "query_num:", query_num
-    print "pair_num:", pos_num + neg_num + equal_num, pair_num
-    print "equal_num:", equal_num
-    print "正序率:", 1.0 * pos_num / (pos_num + neg_num)
-    print pos_num, neg_num
+    print("pnr:{}".format(1.0 * pos_num / neg_num))
+    print("query_num:{}".format(query_num))
+    print("pair_num:{} , {}".format(pos_num + neg_num + equal_num, pair_num))
+    print("equal_num:{}".format(equal_num))
+    print("正序率: {}".format(1.0 * pos_num / (pos_num + neg_num)))
+    print("pos_num: {} , neg_num: {}".format(pos_num, neg_num))