{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0.0\n" ] } ], "source": [ "import collections\n", "import math\n", "import random\n", "import sys\n", "import time\n", "import os\n", "import torch\n", "import torch.utils.data as Data\n", "\n", "sys.path.append(\"..\") \n", "import d2lzh_pytorch as d2l\n", "print(torch.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['ptb.train.txt',\n", " 'ptb.char.valid.txt',\n", " 'README',\n", " 'ptb.char.train.txt',\n", " 'ptb.test.txt',\n", " 'ptb.char.test.txt',\n", " 'ptb.valid.txt']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(\"../../data/ptb\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# sentences: 42068'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with open('../../data/ptb/ptb.train.txt', 'r') as f:\n", " lines = f.readlines()\n", " # st是sentence的缩写\n", " raw_dataset = [st.split() for st in lines]\n", "\n", "'# sentences: %d' % len(raw_dataset)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']\n", "# tokens: 15 ['pierre', '', 'N', 'years', 'old']\n", "# tokens: 11 ['mr.', '', 'is', 'chairman', 'of']\n" ] } ], "source": [ "for st in raw_dataset[:3]:\n", " print('# tokens:', len(st), st[:5])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# tk是token的缩写\n", "counter = collections.Counter([tk for st in raw_dataset for tk in st])\n", "counter = dict(filter(lambda x: x[1] >= 5, counter.items()))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# tokens: 887100'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx_to_token = [tk for tk, _ in counter.items()]\n", "token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}\n", "dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]\n", " for st in raw_dataset]\n", "num_tokens = sum([len(st) for st in dataset])\n", "'# tokens: %d' % num_tokens" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# tokens: 375184'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def discard(idx):\n", " return random.uniform(0, 1) < 1 - math.sqrt(\n", " 1e-4 / counter[idx_to_token[idx]] * num_tokens)\n", "\n", "subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]\n", "'# tokens: %d' % sum([len(st) for st in subsampled_dataset])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# the: before=50770, after=2031'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compare_counts(token):\n", " return '# %s: before=%d, after=%d' % (token, sum(\n", " [st.count(token_to_idx[token]) for st in dataset]), sum(\n", " [st.count(token_to_idx[token]) for st in subsampled_dataset]))\n", "\n", "compare_counts('the')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# join: 
 { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_centers_and_contexts(dataset, max_window_size):\n", "    centers, contexts = [], []\n", "    for st in dataset:\n", "        if len(st) < 2:  # a sentence needs at least 2 words to form a center-context pair\n", "            continue\n", "        centers += st\n", "        for center_i in range(len(st)):\n", "            window_size = random.randint(1, max_window_size)\n", "            indices = list(range(max(0, center_i - window_size),\n", "                                 min(len(st), center_i + 1 + window_size)))\n", "            indices.remove(center_i)  # exclude the center word from its context words\n", "            contexts.append([st[idx] for idx in indices])\n", "    return centers, contexts" ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n", "center 0 has contexts [1]\n", "center 1 has contexts [0, 2]\n", "center 2 has contexts [1, 3]\n", "center 3 has contexts [1, 2, 4, 5]\n", "center 4 has contexts [2, 3, 5, 6]\n", "center 5 has contexts [3, 4, 6]\n", "center 6 has contexts [5]\n", "center 7 has contexts [8, 9]\n", "center 8 has contexts [7, 9]\n", "center 9 has contexts [8]\n" ] } ], "source": [ "tiny_dataset = [list(range(7)), list(range(7, 10))]\n", "print('dataset', tiny_dataset)\n", "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n", "    print('center', center, 'has contexts', context)" ] },
 { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_negatives(all_contexts, sampling_weights, K):\n", "    all_negatives, neg_candidates, i = [], [], 0\n", "    population = list(range(len(sampling_weights)))\n", "    for contexts in all_contexts:\n", "        negatives = []\n", "        while len(negatives) < len(contexts) * K:\n", "            if i == len(neg_candidates):\n", "                # Draw k word indices at random as noise-word candidates,\n", "                # weighted by sampling_weights. For efficiency, k can be set\n", "                # somewhat larger than actually needed.\n", "                i, neg_candidates = 0, random.choices(\n", "                    population, sampling_weights, k=int(1e5))\n", "            neg, i = neg_candidates[i], i + 1\n", "            # A noise word must not be a context word\n", "            if neg not in set(contexts):\n", "                negatives.append(neg)\n", "        all_negatives.append(negatives)\n", "    return all_negatives\n", "\n", "sampling_weights = [counter[w]**0.75 for w in idx_to_token]\n", "all_negatives = get_negatives(all_contexts, sampling_weights, 5)" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def batchify(data):\n", "    max_len = max(len(c) + len(n) for _, c, n in data)\n", "    centers, contexts_negatives, masks, labels = [], [], [], []\n", "    for center, context, negative in data:\n", "        cur_len = len(context) + len(negative)\n", "        centers += [center]\n", "        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n", "        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n", "        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", "    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),\n", "            torch.tensor(masks), torch.tensor(labels))" ] },
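 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Small sanity check, added here for illustration only: apply batchify to two\n", "# hand-made (center, context, negative) triples of different lengths. The\n", "# shorter example is padded with 0s, masks mark real entries vs. padding, and\n", "# labels are 1 only at context-word positions.\n", "toy_batch = [(1, [2, 3], [4, 5, 6]),\n", "             (7, [8], [9])]\n", "for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'],\n", "                      batchify(toy_batch)):\n", "    print(name, ':', data)" ] },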
"traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m512\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mnum_workers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'win32'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_centers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mall_contexts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mall_negatives\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,\n\u001b[1;32m 7\u001b[0m collate_fn=batchify, num_workers=num_workers)\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *tensors)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtensor\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32massert\u001b[0m 
\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtensor\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'" ] } ], "source": [ "import torch.utils.data as Data\n", "\n", "batch_size = 512\n", "num_workers = 0 if sys.platform.startswith('win32') else 4\n", "dataset = Data.TensorDataset(all_centers, all_contexts, all_negatives)\n", "data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,\n", " collate_fn=batchify, num_workers=num_workers)\n", "for batch in data_iter:\n", " for name, data in zip(['centers', 'contexts_negatives', 'masks',\n", " 'labels'], batch):\n", " print(name, 'shape:', data.shape)\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }