{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 使用Word2Vec进行文本语义相似度计算\n", "\n", "本示例展示利用PaddleHub“端到端地”完成文本相似度计算\n", "\n", "## 一、准备文本数据\n", "\n", "如\n", "```\n", "驾驶违章一次扣12分用两个驾驶证处理可以吗 一次性扣12分的违章,能用不满十二分的驾驶证扣分吗\n", "水果放冰箱里储存好吗 中国银行纪念币网上怎么预约\n", "电脑反应很慢怎么办 反应速度慢,电脑总是卡是怎么回事\n", "```\n", "\n", "## 二、分词\n", "利用PaddleHub Module LAC对文本数据进行分词" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# coding:utf-8\n", "# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\"\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "\"\"\"similarity between two sentences\"\"\"\n", "\n", "import numpy as np\n", "import scipy\n", "from scipy.spatial import distance\n", "\n", "from paddlehub.reader.tokenization import load_vocab\n", "import paddle.fluid as fluid\n", "import paddlehub as hub\n", "\n", "raw_data = [\n", " [\"驾驶违章一次扣12分用两个驾驶证处理可以吗\", \"一次性扣12分的违章,能用不满十二分的驾驶证扣分吗\"],\n", " [\"水果放冰箱里储存好吗\", \"中国银行纪念币网上怎么预约\"],\n", " [\"电脑反应很慢怎么办\", \"反应速度慢,电脑总是卡是怎么回事\"]\n", "]\n", "\n", "lac = hub.Module(name=\"lac\")\n", "\n", "processed_data = []\n", "for text_pair in raw_data:\n", " inputs = {\"text\" : text_pair}\n", " results = lac.lexical_analysis(data=inputs, use_gpu=True, batch_size=2)\n", " data = []\n", " for result in results:\n", " data.append(\" \".join(result[\"word\"]))\n", " processed_data.append(data)\n", "\n", "processed_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 三、计算文本语义相似度\n", "\n", "将分词文本中的单词相应替换为wordid,之后输入wor2vec module中计算两个文本语义相似度" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def convert_tokens_to_ids(vocab, text):\n", " wids = []\n", " tokens = text.split(\" \")\n", " for token in tokens:\n", " wid = vocab.get(token, None)\n", " if not wid:\n", " wid = vocab[\"unknown\"]\n", " wids.append(wid)\n", " return wids" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "module = hub.Module(name=\"word2vec_skipgram\")\n", "inputs, outputs, program = module.context(trainable=False)\n", "vocab = load_vocab(module.get_vocab_path())\n", "\n", "word_ids = inputs[\"word_ids\"]\n", "embedding = outputs[\"word_embs\"]\n", "\n", "place = fluid.CPUPlace()\n", "exe = fluid.Executor(place)\n", "feeder = fluid.DataFeeder(feed_list=[word_ids], place=place)\n", "\n", "for item in processed_data:\n", " text_a = convert_tokens_to_ids(vocab, item[0])\n", " text_b = convert_tokens_to_ids(vocab, item[1])\n", "\n", " vecs_a, = exe.run(\n", " program,\n", " feed=feeder.feed([[text_a]]),\n", " fetch_list=[embedding.name],\n", " return_numpy=False)\n", " vecs_a = np.array(vecs_a)\n", " vecs_b, = exe.run(\n", " program,\n", " feed=feeder.feed([[text_b]]),\n", " fetch_list=[embedding.name],\n", " return_numpy=False)\n", " vecs_b = np.array(vecs_b)\n", "\n", " sent_emb_a = np.sum(vecs_a, axis=0)\n", " sent_emb_b = np.sum(vecs_b, axis=0)\n", " cos_sim = 1 - distance.cosine(sent_emb_a, sent_emb_b)\n", "\n", " print(\"text_a: %s; text_b: %s; cosine_similarity: %.5f\" %\n", " (item[0], item[1], cos_sim))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }