提交 02fef740 编写于 作者: Y yinhaofeng

change

上级 05fe3b93
......@@ -51,7 +51,7 @@ runner:
- name: train_runner
class: train
# num of epochs
epochs: 3
epochs: 2
# device to run training or infer
device: cpu
save_checkpoint_interval: 1 # save model interval of epochs
......@@ -68,7 +68,7 @@ runner:
# device to run training or infer
device: cpu
print_interval: 1
init_model_path: "increment/2" # load model path
init_model_path: "increment/1" # load model path
phases: phase2
# runner will run all the phase in each epoch
......
......@@ -16,28 +16,19 @@ for line in lines:
text = line[0].split(" ") + line[1].split(" ")
for word in text:
if word in word_dict:
word_dict[word] = word_dict[word] + 1
continue
else:
word_dict[word] = 1
word_list = word_dict.items()
word_list = sorted(word_dict.items(), key=lambda item: item[1], reverse=True)
word_list_ids = range(1, len(word_list) + 1)
word_dict = dict(zip([x[0] for x in word_list], word_list_ids))
word_dict[word] = len(word_dict) + 1
f = open("./zhidao", "r")
lines = f.readlines()
f.close()
#划分训练集和测试集
lines = [line.strip().split("\t") for line in lines]
random.shuffle(lines)
train_set = lines[:900]
test_set = lines[900:]
#建立以query为key,以负例为value的字典
neg_dict = {}
for line in train_set:
for line in lines:
if line[2] == "0":
if line[0] in neg_dict:
neg_dict[line[0]].append(line[1])
......@@ -46,31 +37,44 @@ for line in train_set:
#建立以query为key,以正例为value的字典
pos_dict = {}
for line in train_set:
for line in lines:
if line[2] == "1":
if line[0] in pos_dict:
pos_dict[line[0]].append(line[1])
else:
pos_dict[line[0]] = [line[1]]
#训练集整理为query,pos,neg的格式
f = open("train.txt", "w")
for query in pos_dict.keys():
#划分训练集和测试集
query_list = list(pos_dict.keys())
#print(len(query_list))
random.shuffle(query_list)
train_query = query_list[:90]
test_query = query_list[90:]
#获得训练集
train_set = []
for query in train_query:
for pos in pos_dict[query]:
if query not in neg_dict:
continue
for neg in neg_dict[query]:
f.write(str(query) + "\t" + str(pos) + "\t" + str(neg) + "\n")
f.close()
train_set.append([query, pos, neg])
random.shuffle(train_set)
f = open("train.txt", "r")
lines = f.readlines()
f.close()
#获得测试集
test_set = []
for query in test_query:
for pos in pos_dict[query]:
test_set.append([query, pos, 1])
if query not in neg_dict:
continue
for neg in neg_dict[query]:
test_set.append([query, neg, 0])
random.shuffle(test_set)
#训练集中的query,pos,neg转化格式
f = open("train.txt", "w")
for line in lines:
line = line.strip().split("\t")
for line in train_set:
query = line[0].strip().split(" ")
pos = line[1].strip().split(" ")
neg = line[2].strip().split(" ")
......@@ -91,6 +95,7 @@ f.close()
#测试集中的query和pos转化格式
f = open("test.txt", "w")
fa = open("label.txt", "w")
fb = open("testquery.txt", "w")
for line in test_set:
query = line[0].strip().split(" ")
pos = line[1].strip().split(" ")
......@@ -98,12 +103,13 @@ for line in test_set:
query_list = []
for word in query:
query_list.append(word_dict[word])
pos_token = []
pos_list = []
for word in pos:
pos_list.append(word_dict[word])
f.write(' '.join(["0:" + str(x) for x in query_list]) + " " + ' '.join(
["1:" + str(x) for x in pos_list]) + "\n")
fa.write(label + "\n")
fa.write(str(label) + "\n")
fb.write(','.join([str(x) for x in query_list]) + "\n")
f.close()
fa.close()
fb.close()
因为 它太大了无法显示 source diff 。你可以改为 查看blob
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlerec.core.reader import ReaderBase
from paddlerec.core.utils import envs
class Reader(ReaderBase):
def init(self):
self.query_slots = envs.get_global_env("hyper_parameters.query_slots",
None, "train.model")
self.title_slots = envs.get_global_env("hyper_parameters.title_slots",
None, "train.model")
self.all_slots = []
for i in range(self.query_slots):
self.all_slots.append(str(i))
for i in range(self.title_slots):
self.all_slots.append(str(i + self.query_slots))
self._all_slots_dict = dict()
for index, slot in enumerate(self.all_slots):
self._all_slots_dict[slot] = [False, index]
def generate_sample(self, line):
def data_iter():
elements = line.rstrip().split()
padding = 0
output = [(slot, []) for slot in self.all_slots]
for elem in elements:
slot, feasign = elem.split(':')
if not self._all_slots_dict.has_key(slot):
continue
self._all_slots_dict[slot][0] = True
index = self._all_slots_dict[slot][1]
output[index][1].append(int(feasign))
for slot in self._all_slots_dict:
visit, index = self._all_slots_dict[slot]
if visit:
self._all_slots_dict[slot][0] = False
else:
output[index][1].append(padding)
yield output
return data_iter
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlerec.core.reader import ReaderBase
from paddlerec.core.utils import envs
class Reader(ReaderBase):
def init(self):
self.query_slots = envs.get_global_env("hyper_parameters.query_slots",
None, "train.model")
self.title_slots = envs.get_global_env("hyper_parameters.title_slots",
None, "train.model")
self.all_slots = []
for i in range(self.query_slots):
self.all_slots.append(str(i))
for i in range(self.title_slots):
self.all_slots.append(str(i + self.query_slots))
for i in range(self.title_slots):
self.all_slots.append(str(i + self.query_slots + self.title_slots))
self._all_slots_dict = dict()
for index, slot in enumerate(self.all_slots):
self._all_slots_dict[slot] = [False, index]
def generate_sample(self, line):
def data_iter():
elements = line.rstrip().split()
padding = 0
output = [(slot, []) for slot in self.all_slots]
for elem in elements:
slot, feasign = elem.split(':')
if not self._all_slots_dict.has_key(slot):
continue
self._all_slots_dict[slot][0] = True
index = self._all_slots_dict[slot][1]
output[index][1].append(int(feasign))
for slot in self._all_slots_dict:
visit, index = self._all_slots_dict[slot]
if visit:
self._all_slots_dict[slot][0] = False
else:
output[index][1].append(padding)
yield output
return data_iter
......@@ -13,8 +13,10 @@
├── README.md #文档
├── model.py #模型文件
├── config.yaml #配置文件
├── run.sh #运行脚本
├── eval.py #评价脚本
├── run.sh #运行脚本,在效果复现时使用
├── transform.py #整理格式准备计算指标的程序
├── reader.py #读者需要自定义数据集时供读者参考
├── evaluate_reader.py #读者需要自定义数据集时供读者参考
```
注:在阅读该示例前,建议您先了解以下内容:
......@@ -79,7 +81,7 @@ rm -f simnet_dataset-1.0.0.tar.gz
mv data/zhidao ./
rm -rf data
```
3. 本文提供了快速将数据集中的汉字数据处理为可训练格式数据的脚本,您在解压数据集后,可以看见目录中存在一个名为zhidao的文件。然后能可以在python3环境下运行我们提供的preprocess.py文件。即可生成可以直接用于训练的数据目录test.txt,train.txt和label.txt。将其放入train和test目录下以备训练时调用。命令如下:
3. 本文提供了快速将数据集中的汉字数据处理为可训练格式数据的脚本,您在解压数据集后,可以看见目录中存在一个名为zhidao的文件。然后能可以在python3环境下运行我们提供的preprocess.py文件。即可生成可以直接用于训练的数据目录test.txt,train.txt,label.txt和testquery.txt。将其放入train和test目录下以备训练时调用。命令如下:
```
python3 preprocess.py
rm -f ./train/train.txt
......@@ -90,17 +92,36 @@ cd ..
```
4. 退回tagspace目录中,打开文件config.yaml,更改其中的参数
将workspace改为您当前的绝对路径。(可用pwd命令获取绝对路径)
将workspace改为您当前的绝对路径。(可用pwd命令获取绝对路径)
5. 执行脚本,开始训练.脚本会运行python -m paddlerec.run -m ./config.yaml启动训练,并将结果输出到result文件中。然后启动评价脚本eval.py计算auc
5. 执行脚本,开始训练.脚本会运行python -m paddlerec.run -m ./config.yaml启动训练,并将结果输出到result文件中。然后启动格式整理程序transform,最后计算正逆序比
```
sh run.sh
```
运行结果大致如下:
```
('auc = ', 0.5944897959183673)
................run.................
!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.
CPU_NUM indicates that how many CPUPlace are used in the current task.
And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.
export CPU_NUM=32 # for example, set CPU_NUM as number of physical CPU core which is 32.
!!! The default number of CPU_NUM=1.
I0821 14:24:57.255358 7888 parallel_executor.cc:440] The Program will be executed on CPU using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0821 14:24:57.259166 7888 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0821 14:24:57.262634 7888 parallel_executor.cc:307] Inplace strategy is enabled, when build_strategy.enable_inplace = True
I0821 14:24:57.264791 7888 parallel_executor.cc:375] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
103
pnr: 1.17674418605
query_num: 11
pair_num: 468 468
equal_num: 0
正序率: 0.540598290598
253 215
```
6. 提醒:因为采取较小的数据集进行训练和测试,得到指标的浮动程度会比较大。如果得到的指标不合预期,可以多次执行步骤5,即可获得合理的指标。
## 进阶使用
## FAQ
......@@ -5,4 +5,7 @@ grep -i "query_pt_sim" ./result1.txt >./result2.txt
sed '$d' result2.txt >result.txt
rm -f result1.txt
rm -f result2.txt
python eval.py
python transform.py
sort -t $'\t' -k1,1 -k 2nr,2 pair.txt >result.txt
rm -f pair.txt
python ../../../tools/cal_pos_neg.py result.txt
......@@ -24,7 +24,7 @@ num = 0
for line in f.readlines():
num = num + 1
line = line.strip()
label.append(float(line))
label.append(line)
f.close()
print(num)
......@@ -38,5 +38,17 @@ for line in open(filename):
line = line.strip("]")
sim.append(float(line))
auc = sklearn.metrics.roc_auc_score(label, sim)
print("auc = ", auc)
filename = './data/testquery.txt'
f = open(filename, "r")
f.readline()
query = []
for line in f.readlines():
line = line.strip()
query.append(line)
f.close()
filename = 'pair.txt'
f = open(filename, "w")
for i in range(len(sim)):
f.write(str(query[i]) + "\t" + str(sim[i]) + "\t" + str(label[i]) + "\n")
f.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册