打印素数.md 5.9 KB
Newer Older
L
labuladong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
# 如何高效寻找素数

素数的定义看起来很简单,如果一个数如果只能被 1 和它本身整除,那么这个数就是素数。

不要觉得素数的定义简单,恐怕没多少人真的能把素数相关的算法写得高效。比如让你写这样一个函数:

```java
// 返回区间 [2, n) 中有几个素数 
int countPrimes(int n)

// 比如 countPrimes(10) 返回 4
// 因为 2,3,5,7 是素数
```

你会如何写这个函数?我想大家应该会这样写:

```java
int countPrimes(int n) {
    int count = 0;
    for (int i = 2; i < n; i++)
        if (isPrim(i)) count++;
    return count;
}

// 判断整数 n 是否是素数
boolean isPrime(int n) {
    for (int i = 2; i < n; i++)
        if (n % i == 0)
            // 有其他整除因子
            return false;
    return true;
}
```

这样写的话时间复杂度 O(n^2),问题很大。**首先你用 isPrime 函数来辅助的思路就不够高效;而且就算你要用 isPrime 函数,这样写算法也是存在计算冗余的**

先来简单说下**如果你要判断一个数是不是素数,应该如何写算法**。只需稍微修改一下上面的 isPrim 代码中的 for 循环条件:

```java
boolean isPrime(int n) {
    for (int i = 2; i * i <= n; i++)
        ...
}
```

换句话说,`i` 不需要遍历到 `n`,而只需要到 `sqrt(n)` 即可。为什么呢,我们举个例子,假设 `n = 12`

```java
12 = 2 × 6
12 = 3 × 4
12 = sqrt(12) × sqrt(12)
12 = 4 × 3
12 = 6 × 2
```

可以看到,后两个乘积就是前面两个反过来,反转临界点就在 `sqrt(n)`

换句话说,如果在 `[2,sqrt(n)]` 这个区间之内没有发现可整除因子,就可以直接断定 `n` 是素数了,因为在区间 `[sqrt(n),n]` 也一定不会发现可整除因子。

现在,`isPrime` 函数的时间复杂度降为 O(sqrt(N)),**但是我们实现 `countPrimes` 函数其实并不需要这个函数**,以上只是希望读者明白 `sqrt(n)` 的含义,因为等会还会用到。


### 高效实现 `countPrimes`

高效解决这个问题的核心思路是和上面的常规思路反着来:

首先从 2 开始,我们知道 2 是一个素数,那么 2 × 2 = 4, 3 × 2 = 6, 4 × 2 = 8... 都不可能是素数了。

然后我们发现 3 也是素数,那么 3 × 2 = 6, 3 × 3 = 9, 3 × 4 = 12... 也都不可能是素数了。

看到这里,你是否有点明白这个排除法的逻辑了呢?先看我们的第一版代码:

```java
int countPrimes(int n) {
    boolean[] isPrim = new boolean[n];
    // 将数组都初始化为 true
    Arrays.fill(isPrim, true);

    for (int i = 2; i < n; i++) 
        if (isPrim[i]) 
            // i 的倍数不可能是素数了
            for (int j = 2 * i; j < n; j += i) 
                    isPrim[j] = false;
    
    int count = 0;
    for (int i = 2; i < n; i++)
        if (isPrim[i]) count++;
    
    return count;
}
```

如果上面这段代码你能够理解,那么你已经掌握了整体思路,但是还有两个细微的地方可以优化。

首先,回想刚才判断一个数是否是素数的 `isPrime` 函数,由于因子的对称性,其中的 for 循环只需要遍历 `[2,sqrt(n)]` 就够了。这里也是类似的,我们外层的 for 循环也只需要遍历到 `sqrt(n)`

```java
for (int i = 2; i * i < n; i++) 
    if (isPrim[i]) 
        ...
```

除此之外,很难注意到内层的 for 循环也可以优化。我们之前的做法是:

```java
for (int j = 2 * i; j < n; j += i) 
    isPrim[j] = false;
```

这样可以把 `i` 的整数倍都标记为 `false`,但是仍然存在计算冗余。

比如 `n = 25``i = 4` 时算法会标记 4 × 2 = 8,4 × 3 = 12 等等数字,但是这两个数字已经被 `i = 2``i = 3` 的 2 × 4 和 3 × 4 标记了。

我们可以稍微优化一下,让 `j``i` 的平方开始遍历,而不是从 `2 * i` 开始:

```java
for (int j = i * i; j < n; j += i) 
    isPrim[j] = false;
```

这样,素数计数的算法就高效实现了,其实这个算法有一个名字,叫做 Sieve of Eratosthenes。看下完整的最终代码:

```java
int countPrimes(int n) {
    boolean[] isPrim = new boolean[n];
    Arrays.fill(isPrim, true);
    for (int i = 2; i * i < n; i++) 
        if (isPrim[i]) 
            for (int j = i * i; j < n; j += i) 
                isPrim[j] = false;
    
    int count = 0;
    for (int i = 2; i < n; i++)
        if (isPrim[i]) count++;
    
    return count;
}
```

**该算法的时间复杂度比较难算**,显然时间跟这两个嵌套的 for 循环有关,其操作数应该是:

  n/2 + n/3 + n/5 + n/7 + ...
= n × (1/2 + 1/3 + 1/5 + 1/7...)

括号中是素数的倒数。其最终结果是 O(N * loglogN),有兴趣的读者可以查一下该算法的时间复杂度证明。

以上就是素数算法相关的全部内容。怎么样,是不是看似简单的问题却有不少细节可以打磨呀?

**致力于把算法讲清楚!欢迎关注我的微信公众号 labuladong,查看更多通俗易懂的文章**

L
labuladong 已提交
151 152
![labuladong](../pictures/labuladong.png)

何睿 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
[ruicore](https://github.com/ruicore/algorithm) 提供 Python3 代码:

```py
class Solution:
    def countPrimes(self, n: int):
        # 前 2 个数不是素数
        if n < 3:
            return 0
        # isprime 数组
        isprime = [1] * n

        for i in range(2, int(n ** 0.5) + 1):
            if isprime[i]:
                # 从 i*i(包括自己) 到 n(不包括n),每步增加 i 的所有数的个数
                tmp = ((n - 1 - i * i) // i + 1)
                # 这种方式赋值比用 for 循环更高效
                isprime[i * i: n: i] = [0] * tmp

        # 前 2 个数不是素数,去掉
        count = sum(isprime[2:])
        return count
```

L
labuladong 已提交
176 177 178 179 180
[上一篇:如何实现LRU算法](../高频面试系列/LRU算法.md)

[下一篇:如何计算编辑距离](../动态规划系列/编辑距离.md)

[目录](../README.md#目录)