diff --git a/models/match/match-pyramid/config.yaml b/models/match/match-pyramid/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c21aef87b865726e93645eb6d733aaa385dec5c --- /dev/null +++ b/models/match/match-pyramid/config.yaml @@ -0,0 +1,89 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +workspace: "paddlerec.models.match.match-pyramid" + +dataset: +- name: dataset_train + batch_size: 128 + type: DataLoader + data_path: "{workspace}/data/train" + data_converter: "{workspace}/train_reader.py" +- name: dataset_infer + batch_size: 1 + type: DataLoader + data_path: "{workspace}/data/test" + data_converter: "{workspace}/test_reader.py" + + +hyper_parameters: + optimizer: + class: adam + learning_rate: 0.001 + strategy: async + emb_path: "./data/embedding.npy" + sentence_left_size: 20 + sentence_right_size: 500 + vocab_size: 193368 + emb_size: 50 + kernel_num: 8 + hidden_size: 20 + hidden_act: "relu" + out_size: 1 + channels: 1 + conv_filter: [2,10] + conv_act: "relu" + pool_size: [6,50] + pool_stride: [6,50] + pool_type: "max" + pool_padding: "VALID" + +mode: [train_runner , infer_runner] +# config of each runner. +# runner is a kind of paddle training class, which wraps the train/infer process. 
+runner: +- name: train_runner + class: train + # num of epochs + epochs: 2 + # device to run training or infer + device: cpu + save_checkpoint_interval: 1 # save model interval of epochs + save_inference_interval: 1 # save inference + save_checkpoint_path: "inference" # save checkpoint path + save_inference_path: "inference" # save inference path + save_inference_feed_varnames: [] # feed vars of save inference + save_inference_fetch_varnames: [] # fetch vars of save inference + init_model_path: "" # load model path + print_interval: 2 + phases: phase_train +- name: infer_runner + class: infer + # device to run training or infer + device: cpu + print_interval: 1 + init_model_path: "inference/1" # load model path + phases: phase_infer + +# runner will run all the phase in each epoch +phase: +- name: phase_train + model: "{workspace}/model.py" # user-defined model + dataset_name: dataset_train # select dataset by name + thread_num: 1 +- name: phase_infer + model: "{workspace}/model.py" # user-defined model + dataset_name: dataset_infer # select dataset by name + thread_num: 1 diff --git a/models/match/match-pyramid/data_process.sh b/models/match/match-pyramid/data_process.sh new file mode 100644 index 0000000000000000000000000000000000000000..dfd3a8748a98aecdc7e89fb8f4c461740286f8e4 --- /dev/null +++ b/models/match/match-pyramid/data_process.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +echo "...........load data................." +wget --no-check-certificate 'https://paddlerec.bj.bcebos.com/match_pyramid/match_pyramid_data.tar.gz' +mv ./match_pyramid_data.tar.gz ./data +rm -rf ./data/relation.test.fold1.txt ./data/realtion.train.fold1.txt +tar -xvf ./data/match_pyramid_data.tar.gz +echo "...........data process..............." 
"""Score match-pyramid predictions with mean average precision (MAP).

Run as a script: reads the labelled relation file and the prediction
dump, groups samples by query id, and prints the sample count, the number
of distinct queries and the MAP over those queries.
"""
import random

import numpy as np

# Labelled test pairs: one sample per line, read as "<label> <query id> ...".
RELATION_FILE = './data/relation.test.fold1.txt'
# Prediction dump, one prediction line per sample.
RESULT_FILE = './result.txt'


def eval_MAP(pred, gt):
    """Return the average precision of one query's candidates.

    Args:
        pred: sequence of scores, one per candidate (higher ranks first).
        gt: sequence of labels aligned with ``pred``; any value ``!= 0``
            counts as relevant.

    Returns:
        The mean of precision@k over the ranks k that hold a relevant
        candidate, or ``0.0`` when there is no relevant candidate
        (including empty input).
    """
    ranked = list(zip(pred, gt))
    # Shuffling before the stable sort breaks score ties in random order
    # instead of always favouring the input order.
    random.shuffle(ranked)
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    hits = 0.0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, 1):
        if label != 0:
            hits += 1
            precision_sum += hits / rank
    if hits == 0:
        return 0.0
    return precision_sum / hits


def _load_relations(path):
    """Read ``(labels, qids)`` from the relation file, skipping its first line."""
    labels = []
    qids = []
    with open(path, "r") as relation_file:
        # The first line is discarded, exactly as the original script did.
        relation_file.readline()
        for row in relation_file:
            fields = row.strip().split()
            labels.append(int(fields[0]))
            qids.append(fields[1])
    return labels, qids


def _parse_prediction(row):
    """Extract the float score from one line of the prediction dump.

    The score is the text after the first ':' inside the second
    comma-separated field, with surrounding spaces and square brackets
    removed.
    """
    second_field = row.strip().split(",")[1]
    raw_score = second_field.split(":")[1].strip(" ")
    return float(raw_score.strip("[").strip("]"))


def _load_predictions(path):
    """Read every prediction score from ``path``, closing the file afterwards."""
    with open(path) as result_file:
        return [_parse_prediction(row) for row in result_file]


def _group_by_query(qids, labels, preds):
    """Map each query id to its list of ``[label, score]`` pairs, in input order."""
    grouped = {}
    for qid, label, score in zip(qids, labels, preds):
        grouped.setdefault(qid, []).append([label, score])
    return grouped


def main():
    """Load labels and predictions, then print sample count, query count and MAP.

    Raises:
        ValueError: if the prediction file holds fewer scores than the
            relation file holds samples (surplus predictions are ignored).
    """
    labels, qids = _load_relations(RELATION_FILE)
    print(len(labels))

    preds = _load_predictions(RESULT_FILE)
    if len(preds) < len(qids):
        raise ValueError("%s holds %d predictions but %s holds %d samples" %
                         (RESULT_FILE, len(preds), RELATION_FILE, len(qids)))

    grouped = _group_by_query(qids, labels, preds)
    print(len(grouped))

    if not grouped:
        # Nothing to average; report 0.0 rather than dividing by zero.
        print("map=", 0.0)
        return

    total = 0
    for pairs in grouped.values():
        columns = np.array(pairs)
        total += eval_MAP(columns[:, 1], columns[:, 0])
    print("map=", total / len(grouped))


if __name__ == "__main__":
    main()
class Model(ModelBase):
    """Match-pyramid network assembled from ``paddle.fluid`` layers.

    Two token-id inputs are embedded with a shared "word_embedding" table,
    multiplied into an interaction matrix, passed through one conv + pool
    stage and two fully connected layers to give one prediction per row.
    """

    def __init__(self, config):
        ModelBase.__init__(self, config)

    def _init_hyper_parameters(self):
        """Copy each ``hyper_parameters.<name>`` setting onto ``self.<name>``."""
        for key in (
                "emb_path",
                "sentence_left_size",
                "sentence_right_size",
                "vocab_size",
                "emb_size",
                "kernel_num",
                "hidden_size",
                "hidden_act",
                "out_size",
                "channels",
                "conv_filter",
                "conv_act",
                "pool_size",
                "pool_stride",
                "pool_type",
                "pool_padding", ):
            setattr(self, key,
                    envs.get_global_env("hyper_parameters." + key))

    def input_data(self, is_infer=False, **kwargs):
        """Declare the two int64 inputs ``sentence_left`` and ``sentence_right``.

        Each is declared with shape ``[-1, <sentence size>, 1]``. The same
        pair is returned whether or not ``is_infer`` is set.
        """

        def token_ids(name, length):
            return fluid.data(
                name=name, shape=[-1, length, 1], dtype='int64', lod_level=0)

        return [
            token_ids("sentence_left", self.sentence_left_size),
            token_ids("sentence_right", self.sentence_right_size),
        ]

    def embedding_layer(self, input):
        """
        Look up ``input`` in the embedding table named "word_embedding".

        The table is initialised from the array stored at ``self.emb_path``
        when that file exists, and with Xavier initialisation otherwise.
        """
        if os.path.isfile(self.emb_path):
            table_init = fluid.initializer.NumpyArrayInitializer(
                np.load(self.emb_path))
        else:
            table_init = fluid.initializer.Xavier()

        return fluid.layers.embedding(
            input=input,
            size=[self.vocab_size, self.emb_size],
            padding_idx=0,
            param_attr=fluid.ParamAttr(
                name="word_embedding", initializer=table_init))

    def conv_pool_layer(self, input):
        """
        Apply one stride-1 "SAME"-padded conv2d followed by one pool2d.

        Filter count, filter size, activation and all pooling settings come
        from the hyper parameters.
        """
        # data format NCHW
        feature_map = fluid.layers.conv2d(
            input=input,
            num_filters=self.kernel_num,
            stride=1,
            padding="SAME",
            filter_size=self.conv_filter,
            act=self.conv_act)
        return fluid.layers.pool2d(
            input=feature_map,
            pool_size=self.pool_size,
            pool_stride=self.pool_stride,
            pool_type=self.pool_type,
            pool_padding=self.pool_padding)

    def net(self, inputs, is_infer=False):
        """Build the network on ``inputs`` (left sentence, right sentence).

        In inference mode the prediction is stored in
        ``self._infer_results["prediction"]`` and nothing else is built.
        Otherwise ``self._cost`` is set to the mean of
        ``max(0, 1 - pos + neg)``.
        """
        left_vec = self.embedding_layer(inputs[0])
        right_vec = self.embedding_layer(inputs[1])

        interaction = fluid.layers.matmul(
            left_vec, right_vec, transpose_y=True)
        # Insert a single channel axis so conv2d receives a 4-D input.
        rows, cols = interaction.shape[1], interaction.shape[2]
        interaction = fluid.layers.reshape(interaction, [-1, 1, rows, cols])

        pooled = self.conv_pool_layer(input=interaction)
        hidden = fluid.layers.fc(input=pooled,
                                 size=self.hidden_size,
                                 act=self.hidden_act)
        prediction = fluid.layers.fc(input=hidden, size=self.out_size)

        if is_infer:
            self._infer_results["prediction"] = prediction
            return

        # Rows 0-63 are taken as positives and rows 64-127 as negatives.
        # NOTE(review): this hard-codes a 128-row batch, matching
        # batch_size: 128 of dataset_train in config.yaml; the row layout
        # itself is set by the data preparation -- confirm before changing
        # either value.
        pos = fluid.layers.slice(
            prediction, axes=[0, 1], starts=[0, 0], ends=[64, 1])
        neg = fluid.layers.slice(
            prediction, axes=[0, 1], starts=[64, 0], ends=[128, 1])

        ones = fluid.layers.fill_constant(
            shape=[64, 1], value=1.0, dtype='float32')
        margin = fluid.layers.elementwise_add(
            fluid.layers.elementwise_sub(ones, pos), neg)
        zeros = fluid.layers.fill_constant(
            shape=[64, 1], value=0.0, dtype='float32')
        hinge = fluid.layers.elementwise_max(zeros, margin)

        self._cost = fluid.layers.mean(hinge)
class Reader(ReaderBase):
    """Turn each line of the test data into a two-slot sample."""

    def init(self):
        """Nothing to set up for this reader."""
        pass

    def generate_sample(self, line):
        """
        Return a generator function that yields the sample parsed from ``line``.

        ``line`` is tab-separated; its first two columns are comma-separated
        integer word ids. Parsing is deferred until the returned generator
        is iterated.
        """

        def reader():
            """
            Yield one ``zip`` pairing "doc1" and "doc2" with their id lists.
            """
            columns = line.strip('\n').split('\t')
            left_ids = list(map(int, columns[0].split(",")))
            right_ids = list(map(int, columns[1].split(",")))
            yield zip(["doc1", "doc2"], [left_ids, right_ids])

        return reader
class Reader(ReaderBase):
    """Turn each line of the training data into a two-slot sample."""

    def init(self):
        """Nothing to set up for this reader."""
        pass

    def generate_sample(self, line):
        """
        Return a generator function that yields the sample parsed from ``line``.

        ``line`` is tab-separated; its first two columns are comma-separated
        integer word ids. Parsing is deferred until the returned generator
        is iterated.
        """

        def reader():
            """
            Yield one ``zip`` pairing "doc1" and "doc2" with their id lists.
            """
            columns = line.strip('\n').split('\t')
            left_ids = list(map(int, columns[0].split(",")))
            right_ids = list(map(int, columns[1].split(",")))
            yield zip(["doc1", "doc2"], [left_ids, right_ids])

        return reader