From b5d37692e813023f7b9c02f40fa41d922b12dbfc Mon Sep 17 00:00:00 2001 From: zhangcr <737772005@qq.com> Date: Fri, 20 Jan 2017 15:34:44 +0800 Subject: [PATCH] first PR of query-relationship --- query_relationship/.gitignore | 8 ++++ query_relationship/data/getdata.sh | 35 ++++++++++++++ query_relationship/data/test.list | 1 + query_relationship/dataprovider.py | 70 ++++++++++++++++++++++++++++ query_relationship/evaluate.py | 35 ++++++++++++++ query_relationship/predict.sh | 25 ++++++++++ query_relationship/train.sh | 24 ++++++++++ query_relationship/trainer_config.py | 52 +++++++++++++++++++++ 8 files changed, 250 insertions(+) create mode 100644 query_relationship/.gitignore create mode 100644 query_relationship/data/getdata.sh create mode 100644 query_relationship/data/test.list create mode 100644 query_relationship/dataprovider.py create mode 100644 query_relationship/evaluate.py create mode 100644 query_relationship/predict.sh create mode 100644 query_relationship/train.sh create mode 100644 query_relationship/trainer_config.py diff --git a/query_relationship/.gitignore b/query_relationship/.gitignore new file mode 100644 index 0000000..f88c476 --- /dev/null +++ b/query_relationship/.gitignore @@ -0,0 +1,8 @@ +data/MQ2007 +data/train.list +data/vali.list +data/pred.list +*.log +output/ +*.pyc +core diff --git a/query_relationship/data/getdata.sh b/query_relationship/data/getdata.sh new file mode 100644 index 0000000..68c4a72 --- /dev/null +++ b/query_relationship/data/getdata.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +DIR="$(cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +#Download MQ2007 dataset +echo "Downloading query-docs data..." +wget http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ2007.rar + +#Extract package +echo "Unzipping..." +unrar x MQ2007.rar + +#Remove compressed package +rm MQ2007.rar + +echo "data/MQ2007/Fold1/train.txt" > train.list +echo "data/MQ2007/Fold1/vali.txt" > test.list +echo "data/MQ2007/Fold1/test.txt" > pred.list + +echo "Done." diff --git a/query_relationship/data/test.list b/query_relationship/data/test.list new file mode 100644 index 0000000..b761331 --- /dev/null +++ b/query_relationship/data/test.list @@ -0,0 +1 @@ +data/MQ2007/Fold1/vali.txt diff --git a/query_relationship/dataprovider.py b/query_relationship/dataprovider.py new file mode 100644 index 0000000..7d41f31 --- /dev/null +++ b/query_relationship/dataprovider.py @@ -0,0 +1,70 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.PyDataProvider2 import * + + +#Define a data provider for "query relationship" +@provider( + input_types={ + 'features1': dense_vector(46), + 'features2': dense_vector(46), + 'label': dense_vector(1) + }, + should_shuffle=False, + cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, file_name): + with open(file_name) as f: + pre_qid = -1 + feats1 = [] + feats2 = [] + l1 = 0 + l2 = 0 + for line in f: + line = line.split('#')[0] + if len(line.split()) < 48: + continue + qid = int(line.split()[1].split(':')[1]) + if pre_qid != qid: + feats1 = [] + for term in line.split()[2:48]: + feats1.append(float(term.split(':')[1])) + l1 = int(line.split()[0]) + pre_qid = qid + feats2 = feats1 + yield feats1, feats2, [0.5] + else: + feats1 = feats2 + feats2 = [] + l1 = l2 + for term in line.split()[2:48]: + feats2.append(float(term.split(':')[1])) + l2 = int(line.split()[0]) + p12 = 0.5 + if l1 > l2: + p12 = 1 + if l1 < l2: + p12 = 0 + yield feats1, feats2, [p12] + + +@provider(input_types={'features': dense_vector(46)}) +def process_predict(settings, file_name): + with open(file_name) as f: + for line in f: + feats = [] + line = line.split('#')[0] + for term in line.split()[2:48]: + feats.append(float(term.split(':')[1])) + yield feats diff --git a/query_relationship/evaluate.py b/query_relationship/evaluate.py new file mode 100644 index 0000000..d1a0211 --- /dev/null +++ b/query_relationship/evaluate.py @@ -0,0 +1,35 @@ +#!/usr/bin/python +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import re +import math + + +def get_best_pass(log_filename): + with open(log_filename, 'r') as f: + text = f.read() + pattern = re.compile('Test.*? cost=([0-9]+\.[0-9]+).*?pass-([0-9]+)', + re.S) + results = re.findall(pattern, text) + sorted_results = sorted(results, key=lambda result: -float(result[0])) + return sorted_results[0] + + +log_filename = sys.argv[1] +log = get_best_pass(log_filename) +print 'Best pass is %s, rank-cost is %s' % (log[1], log[0]) + +evaluate_pass = "output/pass-%s" % log[1] +print "evaluating from pass %s" % evaluate_pass diff --git a/query_relationship/predict.sh b/query_relationship/predict.sh new file mode 100644 index 0000000..9de65f0 --- /dev/null +++ b/query_relationship/predict.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +# pass choice +model="output/pass-00001" +paddle train \ + --config=trainer_config.py \ + --use_gpu=false \ + --job=test \ + --init_model_path=$model \ + --config_args=is_predict=1 \ + --predict_output_dir=./ diff --git a/query_relationship/train.sh b/query_relationship/train.sh new file mode 100644 index 0000000..71c9af6 --- /dev/null +++ b/query_relationship/train.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +paddle train \ + --config=trainer_config.py \ + --save_dir=./output \ + --use_gpu=false \ + --trainer_count=1 \ + --test_all_data_in_one_period=true \ + --num_passes=10 \ + --log_period=500 2>&1 | tee 'train.log' diff --git a/query_relationship/trainer_config.py b/query_relationship/trainer_config.py new file mode 100644 index 0000000..806740e --- /dev/null +++ b/query_relationship/trainer_config.py @@ -0,0 +1,52 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +is_predict = get_config_arg('is_predict', bool, False) +trn = 'data/train.list' if not is_predict else None +tst = 'data/test.list' if not is_predict else 'data/pred.list' +process = 'process' if not is_predict else 'process_predict' + +# 1. read data +define_py_data_sources2( + train_list=trn, test_list=tst, module='dataprovider', obj=process) + +# 2. learning algorithm +batch_size = 5 if not is_predict else 1 +settings( + batch_size=batch_size, + learning_rate=1e-3, + learning_method=RMSPropOptimizer()) + +# 3. Network configuration +feature_num = 46 +hid_num = 6 +if not is_predict: + x1 = data_layer(name='features1', size=feature_num) + x2 = data_layer(name='features2', size=feature_num) + y = data_layer(name='label', size=1) + hidden1 = fc_layer( + name='hidden', input=x1, size=hid_num, act=LinearActivation()) + hidden2 = LayerOutput('hidden', LayerType.FC_LAYER, x2, size=feature_num) + y1 = fc_layer(name='output', input=hidden1, size=1, act=SigmoidActivation()) + y2 = LayerOutput('output', LayerType.FC_LAYER, hidden2, size=1) + outputs(rank_cost(left=y1, right=y2, label=y)) +else: + x = data_layer(name='features', size=feature_num) + hidden = fc_layer( + name='hidden', input=x, size=hid_num, act=LinearActivation()) + y_pred = fc_layer( + name='output', input=hidden, size=1, act=SigmoidActivation()) + outputs(y_pred) -- GitLab