{ "cells": [ { "cell_type": "markdown", "id": "a3deebdc", "metadata": {}, "source": [ "# Cross-Encoder for Semantic Textual Similarity\n", "This model was trained using [SentenceTransformers](https://sbert.net) [Cross-Encoder](https://www.sbert.net/examples/applications/cross-encoder/README.html) class.\n" ] }, { "cell_type": "markdown", "id": "4fc17643", "metadata": {}, "source": [ "## Training Data\n", "This model was trained on the [STS benchmark dataset](http://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark). The model predicts a score between 0 and 1 for the semantic similarity of two sentences.\n" ] }, { "cell_type": "markdown", "id": "f66fb11e", "metadata": {}, "source": [ "## Usage and Performance\n" ] }, { "cell_type": "markdown", "id": "fd12128b", "metadata": {}, "source": [ "Pre-trained models can be used like this:\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d04e39", "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade paddlenlp" ] }, { "cell_type": "code", "execution_count": 1, "id": "d07e31aa", "metadata": { "collapsed": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/envs/paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "\u001b[32m[2022-11-21 02:38:07,127] [ INFO]\u001b[0m - Downloading model_config.json from https://bj.bcebos.com/paddlenlp/models/community/cross-encoder/stsb-TinyBERT-L-4/model_config.json\u001b[0m\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 425kB/s]\n", "\u001b[32m[2022-11-21 02:38:07,197] [ INFO]\u001b[0m - We are using to load 'cross-encoder/stsb-TinyBERT-L-4'.\u001b[0m\n", "\u001b[32m[2022-11-21 02:38:07,198] [ INFO]\u001b[0m - Downloading https://bj.bcebos.com/paddlenlp/models/community/cross-encoder/stsb-TinyBERT-L-4/model_state.pdparams and saved to /root/.paddlenlp/models/cross-encoder/stsb-TinyBERT-L-4\u001b[0m\n", "\u001b[32m[2022-11-21 02:38:07,198] [ INFO]\u001b[0m - Downloading model_state.pdparams from https://bj.bcebos.com/paddlenlp/models/community/cross-encoder/stsb-TinyBERT-L-4/model_state.pdparams\u001b[0m\n", "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54.8M/54.8M [00:00<00:00, 64.7MB/s]\n", "\u001b[32m[2022-11-21 02:38:08,199] [ INFO]\u001b[0m - Already cached /root/.paddlenlp/models/cross-encoder/stsb-TinyBERT-L-4/model_config.json\u001b[0m\n", "W1121 02:38:08.202270 64563 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.2, Runtime API Version: 10.2\n", "W1121 02:38:08.207437 64563 gpu_resources.cc:91] device: 0, cuDNN Version: 7.6.\n", "\u001b[32m[2022-11-21 02:38:09,661] [ INFO]\u001b[0m - Weights from pretrained model not used in BertModel: ['classifier.weight', 'classifier.bias']\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[1, 20, 312], dtype=float32, 
place=Place(gpu:0), stop_gradient=False,\n", " [[[-0.73827386, -0.57349819, 0.47456041, ..., -0.07317579,\n", " 0.23808761, -0.43587247],\n", " [-0.71079123, -0.37019217, 0.44499084, ..., -0.07541266,\n", " 0.22209664, -0.48883811],\n", " [-0.61283624, 0.01138088, 0.46346331, ..., -0.15316986,\n", " 0.38455290, -0.23527470],\n", " ...,\n", " [-0.19267607, -0.42171016, 0.40080610, ..., -0.04322027,\n", " 0.16102640, -0.43728969],\n", " [-0.76348048, 0.00028179, 0.50795513, ..., 0.02495949,\n", " 0.32419923, -0.44668996],\n", " [-0.72070849, -0.48510927, 0.47747549, ..., -0.01621611,\n", " 0.31407145, -0.38287419]]]), Tensor(shape=[1, 312], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n", " [[ 0.38359359, 0.16227540, -0.58949089, -0.67293817, 0.70552814,\n", " 0.74028063, -0.60770833, 0.50480992, 0.71489060, -0.73976040,\n", " -0.11784898, 0.73014355, -0.65726435, 0.17490843, -0.44103470,\n", " 0.62014306, 0.35533482, -0.44271812, -0.61711168, -0.70586687,\n", " 0.69903672, 0.00862758, 0.69424403, 0.31887573, 0.38736165,\n", " 0.02848060, -0.69896543, 0.69952166, 0.56477094, 0.68585342,\n", " 0.66026199, 0.67826200, 0.67839348, 0.74852920, -0.04272985,\n", " 0.76357287, 0.38685408, -0.69717598, 0.69945419, 0.44048944,\n", " -0.66915488, 0.11735962, 0.37215349, 0.73054057, 0.71345085,\n", " 0.66489315, 0.19956835, 0.71552449, 0.64762783, -0.46583632,\n", " -0.09976894, -0.45265704, 0.54242563, 0.42835563, -0.60076892,\n", " 0.69768012, -0.72207040, -0.52898210, 0.34657273, 0.05400079,\n", " 0.57360554, -0.72731823, -0.71799070, -0.37212241, -0.70602018,\n", " -0.71248102, 0.02778789, -0.73165607, 0.46581894, -0.72120243,\n", " 0.60769719, -0.63354278, 0.75307459, 0.00700274, -0.00984141,\n", " -0.58984685, 0.36321065, 0.60098255, -0.72467339, 0.18362086,\n", " 0.10687865, -0.63730168, -0.62655306, -0.00187578, -0.51795095,\n", " -0.64884937, 0.69950461, 0.72286713, 0.72522557, -0.45434299,\n", " -0.43063730, -0.10669708, -0.51012146, 0.66286671, 
0.69542134,\n", " 0.21393165, -0.02928682, 0.67238331, 0.20404275, -0.63556075,\n", " 0.55774790, 0.26141557, 0.70166790, -0.03091500, 0.65226245,\n", " -0.69878876, 0.32701582, -0.68492270, 0.67152256, 0.66395414,\n", " -0.68914133, -0.63889050, 0.71558940, 0.50034380, -0.12911484,\n", " 0.70831281, 0.68631476, -0.41206849, 0.23268108, 0.67747647,\n", " -0.29744238, 0.65135175, -0.70074749, 0.56074560, -0.63501489,\n", " 0.74985635, -0.60603380, 0.66920304, -0.72418481, -0.59756589,\n", " -0.70151484, -0.38735744, -0.66458094, -0.71190053, -0.69316322,\n", " 0.43108079, -0.21692288, 0.70705998, -0.14984211, 0.75786442,\n", " 0.69729054, -0.68925959, -0.46773866, 0.66707891, -0.07957093,\n", " 0.73757517, 0.10062494, -0.73353016, 0.10992812, -0.48824292,\n", " 0.62493157, 0.43311006, -0.15723324, -0.48392498, -0.65230477,\n", " -0.41098344, -0.65238249, -0.41507134, -0.55544889, -0.32195652,\n", " -0.74827588, -0.64071310, -0.49207535, -0.69750905, -0.57037342,\n", " 0.35724813, 0.74778593, 0.49369636, -0.69870174, 0.24547403,\n", " 0.73229605, 0.15653144, 0.41334581, 0.64413625, 0.53084993,\n", " -0.64746642, -0.58720803, 0.63381183, 0.76515305, -0.68342912,\n", " 0.65923864, -0.74662960, -0.72339952, 0.32203752, -0.63402468,\n", " -0.71399093, -0.50430977, 0.26967043, -0.21176267, 0.65678287,\n", " 0.09193933, 0.23962519, 0.59481263, -0.61463839, -0.28634411,\n", " 0.69451737, 0.47513142, 0.30889973, -0.18030594, -0.50777411,\n", " 0.71548641, -0.34869543, -0.01252351, 0.12018032, 0.69536412,\n", " 0.53745425, 0.54889160, -0.10619923, 0.68386155, -0.68498713,\n", " 0.23352134, 0.67296249, -0.12094481, -0.69636226, -0.06552890,\n", " 0.00965041, -0.52394331, 0.72305930, -0.17239039, -0.73262835,\n", " 0.50841606, 0.39529455, -0.70830429, 0.51234418, 0.68391299,\n", " -0.72483873, -0.51841038, -0.58264560, -0.74197364, 0.46386808,\n", " -0.23263671, 0.21232133, -0.69674802, 0.33948907, 0.75922930,\n", " -0.43505231, -0.53149903, -0.65927148, 0.09607304, 
-0.68945718,\n", " 0.66966355, 0.68096715, 0.66396469, 0.13001618, -0.68894261,\n", " -0.66597682, 0.61407733, 0.69670630, 0.63995171, 0.33257753,\n", " 0.66776848, 0.57427299, 0.32768273, 0.69438887, 0.41346189,\n", " -0.71529591, -0.09860074, -0.72291893, 0.16860481, -0.67641008,\n", " 0.70644248, -0.24303547, 0.28892463, 0.56054235, 0.55539572,\n", " 0.70762485, -0.50166684, -0.70544142, -0.74241722, -0.74010289,\n", " 0.70217764, -0.09219251, 0.47989756, -0.17431454, 0.76019192,\n", " -0.09623899, -0.64994997, -0.03216666, 0.70323825, -0.66661566,\n", " 0.71163839, -0.08982500, -0.35390857, 0.61377501, -0.49430367,\n", " 0.49526611, 0.75078416, -0.05324765, -0.75398672, 0.70934319,\n", " 0.21146417, -0.59094489, 0.39163795, -0.67382598, -0.63484156,\n", " -0.27295890, 0.75101918, 0.70603085, 0.71781063, -0.57344818,\n", " -0.22560060, -0.62196493, 0.68178481, 0.61596531, -0.12730023,\n", " -0.69500911, 0.73689735, 0.12627751, -0.26101601, -0.24929181,\n", " 0.68093145, 0.05896470]]))\n" ] } ], "source": [ "import paddle\n", "from paddlenlp.transformers import BertForSequenceClassification\n", "\n", "# Fix the RNG so the random demo token ids are identical on every run.\n", "paddle.seed(42)\n", "\n", "model = BertForSequenceClassification.from_pretrained(\"cross-encoder/stsb-TinyBERT-L-4\")\n", "# Inference only: switch to eval mode so dropout is disabled.\n", "model.eval()\n", "\n", "input_ids = paddle.randint(100, 200, shape=[1, 20])\n", "# No gradients are needed for a plain forward pass.\n", "with paddle.no_grad():\n", "    print(model(input_ids))" ] }, { "cell_type": "markdown", "id": "aeccdfe1", "metadata": {}, "source": [ "> 此模型介绍及权重来源于[https://huggingface.co/cross-encoder/stsb-TinyBERT-L-4](https://huggingface.co/cross-encoder/stsb-TinyBERT-L-4),并转换为飞桨模型格式。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }