# 回溯算法团灭子集、排列、组合问题

GitHub

![](../pictures/souyisou.png) **labuladong 刷题辅助插件上线,欢迎大家使用,[下载地址](https://github.com/labuladong/fucking-algorithm/releases),别忘了点个 star**~ 读完本文,你不仅学会了算法套路,还可以顺便去 LeetCode 上拿下如下题目: [78.子集](https://leetcode-cn.com/problems/subsets) [46.全排列](https://leetcode-cn.com/problems/permutations) [77.组合](https://leetcode-cn.com/problems/combinations) **-----------** 今天就来聊三道考察频率高,而且容易让人搞混的算法问题,分别是求子集(subset),求排列(permutation),求组合(combination)。 这几个问题都可以用回溯算法模板解决,同时子集问题还可以用数学归纳思想解决。读者可以记住这几个问题的回溯套路,就不怕搞不清了。 ### 一、子集 问题很简单,输入一个**不包含重复数字**的数组,要求算法输出这些数字的所有子集。 ```cpp vector> subsets(vector& nums); ``` 比如输入 `nums = [1,2,3]`,你的算法应输出 8 个子集,包含空集和本身,顺序可以不同: [ [],[1],[2],[3],[1,3],[2,3],[1,2],[1,2,3] ] **第一个解法是利用数学归纳的思想**:假设我现在知道了规模更小的子问题的结果,如何推导出当前问题的结果呢? 具体来说就是,现在让你求 `[1,2,3]` 的子集,如果你知道了 `[1,2]` 的子集,是否可以推导出 `[1,2,3]` 的子集呢?先把 `[1,2]` 的子集写出来瞅瞅: [ [],[1],[2],[1,2] ] 你会发现这样一个规律: subset(`[1,2,3]`) - subset(`[1,2]`) = [3],[1,3],[2,3],[1,2,3] 而这个结果,就是把 sebset(`[1,2]`) 的结果中每个集合再添加上 3。 换句话说,如果 `A = subset([1,2])` ,那么: subset(`[1,2,3]`) = A + [A[i].add(3) for i = 1..len(A)] 这就是一个典型的递归结构嘛,`[1,2,3]` 的子集可以由 `[1,2]` 追加得出,`[1,2]` 的子集可以由 `[1]` 追加得出,base case 显然就是当输入集合为空集时,输出子集也就是一个空集。 翻译成代码就很容易理解了: ```cpp vector> subsets(vector& nums) { // base case,返回一个空集 if (nums.empty()) return {{}}; // 把最后一个元素拿出来 int n = nums.back(); nums.pop_back(); // 先递归算出前面元素的所有子集 vector> res = subsets(nums); int size = res.size(); for (int i = 0; i < size; i++) { // 然后在之前的结果之上追加 res.push_back(res[i]); res.back().push_back(n); } return res; } ``` **这个问题的时间复杂度计算比较容易坑人**。我们之前说的计算递归算法时间复杂度的方法,是找到递归深度,然后乘以每次递归中迭代的次数。对于这个问题,递归深度显然是 N,但我们发现每次递归 for 循环的迭代次数取决于 `res` 的长度,并不是固定的。 根据刚才的思路,`res` 的长度应该是每次递归都翻倍,所以说总的迭代次数应该是 2^N。或者不用这么麻烦,你想想一个大小为 N 的集合的子集总共有几个?2^N 个对吧,所以说至少要对 `res` 添加 2^N 次元素。 那么算法的时间复杂度就是 O(2^N) 吗?还是不对,2^N 个子集是 `push_back` 添加进 `res` 的,所以要考虑 `push_back` 这个操作的效率: ```cpp for (int i = 0; i < size; i++) { res.push_back(res[i]); // O(N) res.back().push_back(n); // O(1) } ``` 因为 `res[i]` 也是一个数组呀,`push_back` 是把 `res[i]` copy 一份然后添加到数组的最后,所以一次操作的时间是 O(N)。 综上,总的时间复杂度就是 O(N*2^N),还是比较耗时的。 空间复杂度的话,如果不计算储存返回结果所用的空间的,只需要 O(N) 的递归堆栈空间。如果计算 `res` 所需的空间,应该是 O(N*2^N)。 **第二种通用方法就是回溯算法**。旧文「回溯算法详解」写过回溯算法的模板: ```python result = [] def backtrack(路径, 选择列表): if 满足结束条件: result.add(路径) return for 选择 in 选择列表: 做选择 backtrack(路径, 选择列表) 撤销选择 ``` 只要改造回溯算法的模板就行了: ```cpp vector> res; vector> subsets(vector& nums) { // 记录走过的路径 vector track; backtrack(nums, 0, track); return res; } void backtrack(vector& nums, int start, vector& track) { res.push_back(track); for (int i = start; i < nums.size(); i++) { // 做选择 track.push_back(nums[i]); // 回溯 backtrack(nums, i + 1, track); // 撤销选择 track.pop_back(); } } ``` 可以看见,对 `res` 更新的位置处在前序遍历,也就是说,**`res` 就是树上的所有节点**: ![](../pictures/子集/1.jpg) ### 二、组合 输入两个数字 `n, k`,算法输出 `[1..n]` 中 k 个数字的所有组合。 ```cpp vector> combine(int n, int k); ``` 比如输入 `n = 4, k = 2`,输出如下结果,顺序无所谓,但是不能包含重复(按照组合的定义,`[1,2]` 和 `[2,1]` 也算重复): [ [1,2], [1,3], [1,4], [2,3], [2,4], [3,4] ] 这也是典型的回溯算法,`k` 限制了树的高度,`n` 限制了树的宽度,继续套我们以前讲过的回溯算法模板框架就行了: ![](../pictures/子集/2.jpg) ```cpp vector>res; vector> combine(int n, int k) { if (k <= 0 || n <= 0) return res; vector track; backtrack(n, k, 1, track); return res; } void backtrack(int n, int k, int start, vector& track) { // 到达树的底部 if (k == track.size()) { res.push_back(track); return; } // 注意 i 从 start 开始递增 for (int i = start; i <= n; i++) { // 做选择 track.push_back(i); backtrack(n, k, i + 1, track); // 撤销选择 track.pop_back(); } } ``` `backtrack` 函数和计算子集的差不多,区别在于,更新 `res` 的时机是树到达底端时。 ### 三、排列 输入一个**不包含重复数字**的数组 `nums`,返回这些数字的全部排列。 ```cpp vector> permute(vector& nums); ``` 比如说输入数组 `[1,2,3]`,输出结果应该如下,顺序无所谓,不能有重复: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1] ] 「回溯算法详解」中就是拿这个问题来解释回溯模板的。这里又列出这个问题,是将「排列」和「组合」这两个回溯算法的代码拿出来对比。 首先画出回溯树来看一看: ![](../pictures/子集/3.jpg) 我们当时使用 Java 代码写的解法: ```java List> res = new LinkedList<>(); /* 主函数,输入一组不重复的数字,返回它们的全排列 */ List> permute(int[] nums) { // 记录「路径」 LinkedList track = new LinkedList<>(); backtrack(nums, track); return res; } void backtrack(int[] nums, LinkedList track) { // 触发结束条件 if (track.size() == nums.length) { res.add(new LinkedList(track)); return; } for (int i = 0; i < nums.length; i++) { // 排除不合法的选择 if (track.contains(nums[i])) continue; // 做选择 track.add(nums[i]); // 进入下一层决策树 backtrack(nums, track); // 取消选择 track.removeLast(); } } ``` 回溯模板依然没有变,但是根据排列问题和组合问题画出的树来看,排列问题的树比较对称,而组合问题的树越靠右节点越少。 在代码中的体现就是,排列问题每次通过 `contains` 方法来排除在 `track` 中已经选择过的数字;而组合问题通过传入一个 `start` 参数,来排除 `start` 索引之前的数字。 **以上,就是排列组合和子集三个问题的解法,总结一下**: 子集问题可以利用数学归纳思想,假设已知一个规模较小的问题的结果,思考如何推导出原问题的结果。也可以用回溯算法,要用 `start` 参数排除已选择的数字。 组合问题利用的是回溯思想,结果可以表示成树结构,我们只要套用回溯算法模板即可,关键点在于要用一个 `start` 排除已经选择过的数字。 排列问题是回溯思想,也可以表示成树结构套用算法模板,关键点在于使用 `contains` 方法排除已经选择的数字,前文有详细分析,这里主要是和组合问题作对比。 记住这几种树的形状,就足以应对大部分回溯算法问题了,无非就是 `start` 或者 `contains` 剪枝,也没啥别的技巧了。 **_____________** **刷算法,学套路,认准 labuladong,公众号和 [在线电子书](https://labuladong.gitee.io/algo/) 持续更新最新文章**。 **本小抄即将出版,微信扫码关注公众号,后台回复「小抄」限时免费获取,回复「进群」可进刷题群一起刷题,带你搞定 LeetCode**。

======其他语言代码====== [78.子集](https://leetcode-cn.com/problems/subsets) [46.全排列](https://leetcode-cn.com/problems/permutations) [77.组合](https://leetcode-cn.com/problems/combinations) ### java [userLF](https://github.com/userLF)提供全排列的java代码: ```java import java.util.ArrayList; import java.util.List; class Solution { List> res = new ArrayList<>(); public List> permute(int[] nums) { res.clear(); dfs(nums, 0);// return res; } public void dfs(int[] n, int start) {//start表示要被替换元素的位置 if( start >= n.length) { List list = new ArrayList(); for(int i : n) { list.add(i); } res.add(list); return; } for(int i = start; i< n.length; i++) {//i从start开始,如果从start+1开始的话,会把当前序列遗漏掉直接保存了下一个序列 int temp= n[i]; n[i] = n[start]; n[start] = temp; dfs(n, start + 1);//递归下一个位置 //回到上一个状态 n[start] = n[i]; n[i] = temp; } } } ``` ### javascript [传送门:78. 子集](https://leetcode-cn.com/problems/subsets/) 数学归纳思想 ```js /** * @param {number[]} nums * @return {number[][]} */ var subsets = function (nums) { // base case, 返回一个空集 if (nums.length === 0) { return [[]] } // 把最后一个元素拿出来 let n = nums.pop(); // 递归算出前面元素的所有子集 let res = subsets(nums); let size = res.length; for (let i = 0; i < size; i++) { // 然后在之前的结果之上追加 res.push(res[i]); res[res.length - 1].push(n); } return res; } ``` 回溯思想 ```js /** * @param {number[]} nums * @return {number[][]} */ const subsets = (nums) => { // 1. 设置结果集 const result = [[]]; // 2. 数组排序 nums.sort((a, b) => a - b); // 3. 递归 const recursion = (index, path) => { // 3.1 设置终止条件 if (path.length === nums.length) { return; } // 3.2 遍历数组 for (let i = index; i < nums.length; i++) { // 3.2.1 添加内容 path.push(nums[i]); // 3.2.2 添加结果集 result.push(path.concat()); // 3.2.3 进一步递归 recursion(i + 1, path); // 3.2.4 回溯,还原之前状态,以备下一次使用 path.pop(); } }; recursion(0, []); // 4. 返回结果 return result; }; console.log(subsets([1, 2, 3])); ``` [传送门:46.全排列](https://leetcode-cn.com/problems/permutations) 不得不说,用js实现递归总是能遇上很多坑,其中多半都是js内置函数和引用问题的坑。看完本文,你很容易就能写入如下的js解法,但结果放到leetcode一跑,输出却十分迷惑。 ```js let res = [] /** * @param {number[]} nums * @return {number[][]} */ var permute = function (nums) { let track = []; backtrack(nums, track); return res; }; var backtrack = function (nums, track) { // 触发结束条件 if (track.length === nums.length) { res.push(track) return; } for (let i = 0; i < nums.length; i++) { // 排除不合法的选择 if (track.indexOf(nums[i]) > -1) { continue; } // 做选择 track.push(nums[i]); // 进入下一层决策树 backtrack(nums, track); // 取消选择 track.pop() } } ``` ``` 输入:[1,2,3] 输出结果:[[],[],[],[],[],[]] 预期结果:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]] ``` 经过借鉴和改进后,无bug写法如下。 ```js /** * @param {number[]} nums * @return {number[][]} */ var permute = function(nums){ // 1. 设置结果集 const result = []; // 2. 回溯 const recursion = (path, set) => { // 2.1 设置回溯终止条件 if (path.length === nums.length) { // 2.1.1 推入结果集 result.push(path.concat()); // 2.1.2 终止递归 return; } // 2.2 遍历数组 for (let i = 0; i < nums.length; i++) { // 2.2.1 必须是不存在 set 中的坐标 if (!set.has(i)) { // 2.2.2 本地递归条件(用完记得删除) path.push(nums[i]); set.add(i); // 2.2.3 进一步递归 recursion(path, set); // 2.2.4 回溯:撤回 2.2.2 的操作 path.pop(); set.delete(i); } } }; recursion([], new Set()); // 3. 返回结果 return result; }; ``` 看到这,才恍然大悟,原来是在 2.1.1 推入结果集的过程中,需要使用concat来实现数组的浅复制,并加入到result结果集中,不然会因为引用问题而导致结果集中都是空list。除此之外,你还要注意格外let块级作用域,在作用域外是找不到的! [传送门:77.组合](https://leetcode-cn.com/problems/combinations) ```js var combine = function (n, k) { const res = [] if (k <= 0 || n <= 0) return res; const track = []; const backtrack = (n, k, start, track) => { // 到达树的底部 if (k === track.length) { res.push(track.concat()); return; } // 注意i从start开始递增 for (let i = start; i <= n; i++) { // 做选择 track.push(i); backtrack(n, k, i + 1, track); // 撤销选择 track.pop(); } } backtrack(n, k, 1, track); return res; }; ```