水塘抽样.md 10.5 KB
Newer Older
1 2 3 4
# 随机算法之水塘抽样算法

<p align='center'>
<a href="https://github.com/labuladong/fucking-algorithm" target="view_window"><img alt="GitHub" src="https://img.shields.io/github/stars/labuladong/fucking-algorithm?label=Stars&style=flat-square&logo=GitHub"></a>
L
labuladong 已提交
5
<a href="https://appktavsiei5995.pc.xiaoe-tech.com/index" target="_blank"><img class="my_header_icon" src="https://img.shields.io/static/v1?label=精品课程&message=查看&color=pink&style=flat"></a>
6 7 8 9
<a href="https://www.zhihu.com/people/labuladong"><img src="https://img.shields.io/badge/%E7%9F%A5%E4%B9%8E-@labuladong-000000.svg?style=flat-square&logo=Zhihu"></a>
<a href="https://space.bilibili.com/14089380"><img src="https://img.shields.io/badge/B站-@labuladong-000000.svg?style=flat-square&logo=Bilibili"></a>
</p>

L
labuladong 已提交
10
![](https://labuladong.github.io/algo/images/souyisou1.png)
11

L
labuladong 已提交
12
**通知:[数据结构精品课](https://aep.h5.xeknow.com/s/1XJHEO) 已更新到 V2.0;[第 13 期刷题打卡](https://mp.weixin.qq.com/s/eUG2OOzY3k_ZTz-CFvtv5Q) 最后一天报名!另外,建议你在我的 [网站](https://labuladong.github.io/algo/) 学习文章,体验更好。**
13 14 15



L
labuladong 已提交
16 17 18 19 20 21
读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

| LeetCode | 力扣 | 难度 |
| :----: | :----: | :----: |
| [382. Linked List Random Node](https://leetcode.com/problems/linked-list-random-node/) | [382. 链表随机节点](https://leetcode.cn/problems/linked-list-random-node/) | 🟠
| [398. Random Pick Index](https://leetcode.com/problems/random-pick-index/) | [398. 随机数索引](https://leetcode.cn/problems/random-pick-index/) | 🟠
22 23 24

**-----------**

L
labuladong 已提交
25
我最近在力扣上做到两道非常有意思的题目,382 和 398 题,关于水塘抽样算法(Reservoir Sampling),本质上是一种随机概率算法,解法应该说会者不难,难者不会。
L
labuladong 已提交
26 27 28 29 30 31 32 33 34 35 36

我第一次见到这个算法问题是谷歌的一道算法题:给你一个**未知长度**的链表,请你设计一个算法,**只能遍历一次**,随机地返回链表中的一个节点。

这里说的随机是均匀随机(uniform random),也就是说,如果有 `n` 个元素,每个元素被选中的概率都是 `1/n`,不可以有统计意义上的偏差。

一般的想法就是,我先遍历一遍链表,得到链表的总长度 `n`,再生成一个 `[1,n]` 之间的随机数为索引,然后找到索引对应的节点,不就是一个随机的节点了吗?

但题目说了,只能遍历一次,意味着这种思路不可行。题目还可以再泛化,给一个未知长度的序列,如何在其中随机地选择 `k` 个元素?想要解决这个问题,就需要著名的水塘抽样算法了。

### 算法实现

L
labuladong 已提交
37
**先解决只抽取一个元素的问题**,这个问题的难点在于,随机选择是「动态」的,比如说你现在你有 5 个元素,你已经随机选取了其中的某个元素 `a` 作为结果,但是现在再给你一个新元素 `b`,你应该留着 `a` 还是将 `b` 作为结果呢?以什么逻辑做出的选择,才能保证你的选择方法在概率上是公平的呢?
L
labuladong 已提交
38 39 40 41 42 43 44 45 46 47 48

**先说结论,当你遇到第 `i` 个元素时,应该有 `1/i` 的概率选择该元素,`1 - 1/i` 的概率保持原有的选择**。看代码容易理解这个思路:

```java
/* 返回链表中一个随机节点的值 */
int getRandom(ListNode head) {
    Random r = new Random();
    int i = 0, res = 0;
    ListNode p = head;
    // while 循环遍历链表
    while (p != null) {
L
labuladong 已提交
49
        i++;
L
labuladong 已提交
50 51
        // 生成一个 [0, i) 之间的整数
        // 这个整数等于 0 的概率就是 1/i
L
labuladong 已提交
52
        if (0 == r.nextInt(i)) {
L
labuladong 已提交
53 54 55 56 57 58 59 60 61 62 63 64
            res = p.val;
        }
        p = p.next;
    }
    return res;
}
```

对于概率算法,代码往往都是很浅显的,但是这种问题的关键在于证明,你的算法为什么是对的?为什么每次以 `1/i` 的概率更新结果就可以保证结果是平均随机(uniform random)?

**证明**:假设总共有 `n` 个元素,我们要的随机性无非就是每个元素被选择的概率都是 `1/n` 对吧,那么对于第 `i` 个元素,它被选择的概率就是:

L
labuladong 已提交
65
![](https://labuladong.github.io/algo/images/水塘抽样/formula1.png)
L
labuladong 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102

`i` 个元素被选择的概率是 `1/i`,第 `i+1` 次不被替换的概率是 `1 - 1/(i+1)`,以此类推,相乘就是第 `i` 个元素最终被选中的概率,就是 `1/n`

因此,该算法的逻辑是正确的。

**同理,如果要随机选择 `k` 个数,只要在第 `i` 个元素处以 `k/i` 的概率选择该元素,以 `1 - k/i` 的概率保持原有选择即可**。代码如下:

```java
/* 返回链表中 k 个随机节点的值 */
int[] getRandom(ListNode head, int k) {
    Random r = new Random();
    int[] res = new int[k];
    ListNode p = head;

    // 前 k 个元素先默认选上
    for (int j = 0; j < k && p != null; j++) {
        res[j] = p.val;
        p = p.next;
    }

    int i = k;
    // while 循环遍历链表
    while (p != null) {
        // 生成一个 [0, i) 之间的整数
        int j = r.nextInt(++i);
        // 这个整数小于 k 的概率就是 k/i
        if (j < k) {
            res[j] = p.val;
        }
        p = p.next;
    }
    return res;
}
```

对于数学证明,和上面区别不大:

L
labuladong 已提交
103
![](https://labuladong.github.io/algo/images/水塘抽样/formula2.png)
L
labuladong 已提交
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120

因为虽然每次更新选择的概率增大了 `k` 倍,但是选到具体第 `i` 个元素的概率还是要乘 `1/k`,也就回到了上一个推导。

### 拓展延伸

以上的抽样算法时间复杂度是 O(n),但不是最优的方法,更优化的算法基于几何分布(geometric distribution),时间复杂度为 O(k + klog(n/k))。由于涉及的数学知识比较多,这里就不列出了,有兴趣的读者可以自行搜索一下。

还有一种思路是基于「Fisher–Yates 洗牌算法」的。随机抽取 `k` 个元素,等价于对所有元素洗牌,然后选取前 `k` 个。只不过,洗牌算法需要对元素的随机访问,所以只能对数组这类支持随机存储的数据结构有效。

另外有一种思路也比较有启发意义:给每一个元素关联一个随机数,然后把每个元素插入一个容量为 `k` 的二叉堆(优先级队列)按照配对的随机数进行排序,最后剩下的 `k` 个元素也是随机的。

这个方案看起来似乎有点多此一举,因为插入二叉堆需要 O(logk) 的时间复杂度,所以整个抽样算法就需要 O(nlogk) 的复杂度,还不如我们最开始的算法。但是,这种思路可以指导我们解决**加权随机抽样算法**,权重越高,被随机选中的概率相应增大,这种情况在现实生活中是很常见的,比如你不往游戏里充钱,就永远抽不到皮肤。

最后,我想说随机算法虽然不多,但其实很有技巧的,读者不妨思考两个常见且看起来很简单的问题:

1、如何对带有权重的样本进行加权随机抽取?比如给你一个数组 `w`,每个元素 `w[i]` 代表权重,请你写一个算法,按照权重随机抽取索引。比如 `w = [1,99]`,算法抽到索引 0 的概率是 1%,抽到索引 1 的概率是 99%。

L
labuladong 已提交
121 122
答案见 [我的这篇文章](https://labuladong.github.io/article/fname.html?fname=随机权重)

L
labuladong 已提交
123 124
2、实现一个生成器类,构造函数传入一个很长的数组,请你实现 `randomGet` 方法,每次调用随机返回数组中的一个元素,多次调用不能重复返回相同索引的元素。要求不能对该数组进行任何形式的修改,且操作的时间复杂度是 O(1)。

L
labuladong 已提交
125
答案见 [我的这篇文章](https://labuladong.github.io/article/fname.html?fname=随机集合)
L
labuladong 已提交
126

L
labuladong 已提交
127 128 129 130




131
**_____________**
L
labuladong 已提交
132

L
labuladong 已提交
133
**《labuladong 的算法小抄》已经出版,关注公众号查看详情;后台回复关键词「**进群**」可加入算法群;回复「**全家桶**」可下载配套 PDF 和刷题全家桶**
L
labuladong 已提交
134 135

![](https://labuladong.github.io/algo/images/souyisou2.png)
L
labuladong 已提交
136 137


B
brucecat 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
======其他语言代码======

[382.链表随机节点](https://leetcode-cn.com/problems/linked-list-random-node)

[398.随机数索引](https://leetcode-cn.com/problems/random-pick-index)

### javascript

题目传送门:[382. 链表随机节点](https://leetcode-cn.com/problems/linked-list-random-node/)

返回链表中一个随机节点的值

```js
/**
 * Definition for singly-linked list.
 * function ListNode(val, next) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.next = (next===undefined ? null : next)
 * }
 */
/**
 * @param head The linked list's head.
 Note that the head is guaranteed to be not null, so it contains at least one node.
 * @param {ListNode} head
 */
var Solution = function (head) {
    this.head = head;
};

/**
 * Returns a random node's value.
 * @return {number}
 */
Solution.prototype.getRandom = function () {
    let i = 0, res = 0;
    let p = this.head;

    // while循环遍历链表
    while (p != null) {
        // 生成一个 [0, i) 之间的整数
        // 这个整数等于 0 的概率就是 1/i
        if (Math.floor(Math.random()*(++i)) === 0) {
            res = p.val;
        }
        p = p.next;
    }
    return res;
};

/**
 * Your Solution object will be instantiated and called as such:
 * var obj = new Solution(head)
 * var param_1 = obj.getRandom()
 */
```



**[题目传送门:398. 随机数索引](https://leetcode-cn.com/problems/random-pick-index/)**

假设当前正要读取第n个数据,则我们以1/n的概率留下该数据,否则以(n-1)/n 的概率留下前n-1个数据中的一个。
而前n-1个数组留下的那个概率为1/(n-1),因此最终留下上次n-1个数中留下的那个数的概率为[1/(n-1)]*[(n-1)/n] = 1/n,符合均匀分布的要求。

```js
/**
 * @param {number[]} nums
 */
var Solution = function (nums) {
    this.nums = nums;
    // 需要k个数据,在本题中,k等于1
    this.need = 1;
};

// 假设当前正要读取第n个数据,则我们以1/n的概率留下该数据,否则以(n-1)/n 的概率留下前n-1个数据中的一个。
// 而前n-1个数组留下的那个概率为1/(n-1),
// 因此最终留下上次n-1个数中留下的那个数的概率为[1/(n-1)]*[(n-1)/n] = 1/n,符合均匀分布的要求
//

Solution.prototype.pick = function (target) {
    let count = 0
    let res = 0

    this.nums.forEach((value, key) => {
        if (value === target) {
            //我们的目标对象中选取。
            count++;
            //我们以1/n的概率留下该数据
            if (Math.floor(Math.random() * count) === 0) {
                res = key;
            }
        }
    })
    return res
};
```
L
labuladong 已提交
233