Python常用算法——回溯算法
回溯算法思想与经典案例解析
一、回溯算法概述
回溯算法(Backtracking)是一种通过暴力穷举方式解决问题的系统搜索方法,它是深度优先搜索的一种具体应用。回溯算法的核心思想是:走不通就退回重走,在包含问题所有解的路径树中,按照深度优先策略,从根节点出发搜索解空间树。
1.1 回溯算法的基本特征
回溯算法具有以下特征:
- 系统性:按深度优先策略系统搜索解空间
- 跳跃性:判断节点不包含解时,跳过该子树回溯
- 递归实现:通常用最简单的递归方法实现
1.2 回溯算法的步骤
使用回溯算法的一般步骤如下:
- 定义解空间:子集树或排列树
- 组织解空间:用适合搜索的方法组织
- 深度优先搜索:按深度优先策略搜索
- 剪枝优化:利用剪枝函数避免无效搜索
1.3 回溯算法的三要素
- 解空间:要解决问题的范围
- 约束条件:包括显性和隐性的限制
- 状态树:搜索过程展开的依据
二、经典回溯算法案例
下面通过15个经典案例,深入理解回溯算法的应用。
2.1 八皇后问题
问题描述:在8×8的国际象棋棋盘上摆放8个皇后,使其不能互相攻击(任意两个皇后不在同一行、同一列、同一对角线上)。
代码实现:
n = 8
x = [] # 一个解(n元数组)
X = [] # 一组解
# 冲突检测
def conflict(k):
global x
for i in range(k):
if x[i] == x[k] or abs(x[i] - x[k]) == abs(i - k):
return True
return False
# 回溯求解
def queens(k):
global n, x, X
if k >= n:
X.append(x[:])
else:
for i in range(n):
x.append(i)
if not conflict(k):
queens(k + 1)
x.pop()
# 可视化
def show(x):
for i in range(n):
print('. ' * (x[i]) + 'X ' + '. ' * (n - x[i] - 1))
queens(0)
show(X[0])
核心要点:
- 用x[i]表示第i行皇后所在的列
- 冲突条件:同列(x[i]==x[k])或对角线(abs(x[i]-x[k])==abs(i-k))
- 递归逐行放置皇后
2.2 0-1背包问题
问题描述:给定n个物品,每个物品有重量w和价值v,在背包容量c的限制下,如何选择物品使总价值最大。
代码实现:
n = 3 # 物品数量
c = 30 # 背包容量
w = [20, 15, 15] # 物品重量
v = [45, 25, 25] # 物品价值
bag = [0, 0, 0] # 一个解(n元0-1数组)
maxv = 0 # 最大价值
bestbag = None # 最佳解
# 冲突检测
def conflict(k):
if sum([w[i] for i in range(k + 1) if bag[i] == 1]) > c:
return True
return False
# 回溯求解
def backpack(k):
global bag, maxv, bestbag
if k == n:
cv = sum([v[i] for i in range(n) if bag[i] == 1])
if cv > maxv:
maxv = cv
bestbag = bag[:]
else:
for i in [1, 0]: # 选或不选
bag[k] = i
if not conflict(k):
backpack(k + 1)
backpack(0)
print(bestbag, maxv)
核心要点:
- 子集树模板:每个物品有选或不选两种状态
- 剪枝条件:当前重量超过背包容量
2.3 旅行商问题(TSP)
问题描述:给定n个城市及城市间的距离,找一条经过所有城市一次且回到起点的最短路径。
代码实现:
n = 5 # 节点数
graph = [
{1: 7, 2: 6, 3: 1, 4: 3},
{0: 7, 2: 3, 3: 7, 4: 8},
{0: 6, 1: 3, 3: 12, 4: 11},
{0: 1, 1: 7, 2: 12, 4: 2},
{0: 3, 1: 8, 2: 11, 3: 2}
]
x = [0] * (n + 1) # 一个解(n+1元数组,最后回到起点)
min_cost = 0
best_x = None
# 冲突检测
def conflict(k):
if k < n and x[k] in x[:k]:
return True
if k == n and x[k] != x[0]:
return True
cost = sum([graph[x[i]][x[i + 1]] for i in range(k)])
if 0 < min_cost < cost:
return True
return False
# 回溯求解
def tsp(k):
global min_cost, best_x
if k > n:
cost = sum([graph[x[i]][x[i + 1]] for i in range(n)])
if min_cost == 0 or cost < min_cost:
min_cost = cost
best_x = x[:]
else:
for node in graph[x[k - 1]]:
x[k] = node
if not conflict(k):
tsp(k + 1)
x[0] = 2 # 从节点c出发
tsp(1)
print(best_x, min_cost)
核心要点:
- 排列树模板:每个城市只访问一次
- 剪枝条件:当前路径长度已超过已知最优解
2.4 图的m着色问题
问题描述:给定无向连通图和m种颜色,给每个顶点着色,使相邻顶点颜色不同。
代码实现:
n = 5
graph = [
{1, 2, 3}, # 节点0的邻接点
{0, 2, 3, 4}, # 节点1的邻接点
{0, 1, 3}, # 节点2的邻接点
{0, 1, 2, 4}, # 节点3的邻接点
{1, 3} # 节点4的邻接点
]
m = 4 # 4种颜色
x = [0] * n # 一个解(每个节点的颜色编号)
# 冲突检测
def conflict(k):
nodes = [node for node in range(k) if node in graph[k]]
if x[k] in [x[node] for node in nodes]:
return True
return False
# 回溯求解
def dfs(k):
if k == n:
print(x)
else:
for color in range(m):
x[k] = color
if not conflict(k):
dfs(k + 1)
dfs(0)
核心要点:
- 每个节点有m种颜色可选
- 冲突条件:相邻节点颜色相同
2.5 迷宫问题
问题描述:在给定的迷宫中,找一条从入口到出口的路径。
代码实现:
maze = [[1,1,1,1,1,1,1,1,1,1],
[0,0,1,0,1,1,1,1,0,1],
[1,1,0,1,0,1,1,0,1,1],
[1,0,1,1,1,0,0,1,1,1],
[1,1,1,0,0,1,1,0,1,1],
[1,1,0,1,1,1,1,1,0,1],
[1,0,1,0,0,1,1,1,1,0],
[1,1,1,1,1,0,1,1,1,1]]
entry = (1, 0) # 入口
path = [entry]
paths = []
# 8个移动方向
directions = [(-1,0), (-1,1), (0,1), (1,1), (1,0), (1,-1), (0,-1), (-1,-1)]
# 冲突检测
def conflict(nx, ny):
if 0 <= nx < 8 and 0 <= ny < 10 and maze[nx][ny] == 0:
return False
return True
# 回溯求解
def walk(x, y):
if (x, y) != entry and (x % 7 == 0 or y % 9 == 0):
paths.append(path[:])
else:
for d in directions:
nx, ny = x + d[0], y + d[1]
path.append((nx, ny))
if not conflict(nx, ny):
maze[nx][ny] = 2
walk(nx, ny)
maze[nx][ny] = 0
path.pop()
walk(1, 0)
print(paths[-1])
核心要点:
- 每个位置有8个方向可探索
- 用标记法避免重复访问
2.6 骑士巡游问题
问题描述:在国际象棋棋盘上,让骑士走遍所有格子且不重复。
代码实现:
SIZE = 5
total = 0
def print_board(board):
for row in board:
for col in row:
print(str(col).center(4), end='')
print()
def patrol(board, row, col, step=1):
if row >= 0 and row < SIZE and col >= 0 and col < SIZE and board[row][col] == 0:
board[row][col] = step
if step == SIZE * SIZE:
global total
total += 1
print(f'第{total}种走法: ')
print_board(board)
# 8个可能的方向
patrol(board, row - 2, col - 1, step + 1)
patrol(board, row - 1, col - 2, step + 1)
patrol(board, row + 1, col - 2, step + 1)
patrol(board, row + 2, col - 1, step + 1)
patrol(board, row + 2, col + 1, step + 1)
patrol(board, row + 1, col + 2, step + 1)
patrol(board, row - 1, col + 2, step + 1)
patrol(board, row - 2, col + 1, step + 1)
board[row][col] = 0
board = [[0] * SIZE for _ in range(SIZE)]
patrol(board, SIZE - 1, SIZE - 1)
核心要点:
- 骑士有8种走法
- 用步数标记访问顺序
2.7 全排列问题
问题描述:求给定集合的所有排列。
代码实现:
n = 4
a = ['a', 'b', 'c', 'd']
x = [0] * n
def perm(k):
if k >= n:
print(x)
else:
for i in set(a) - set(x[:k]):
x[k] = i
perm(k + 1)
perm(0)
另一种实现(交换法):
x = [1, 2, 3, 4]
def backkrak(k):
if k >= n:
print(x)
else:
for i in range(k, n):
x[k], x[i] = x[i], x[k]
backkrak(k + 1)
x[i], x[k] = x[k], x[i]
backkrak(0)
核心要点:
- 排列树模板
- 交换法更高效
2.8 组合问题
问题描述:从n个元素中选r个的所有组合。
代码实现:
n, r = 5, 3
a = [1, 2, 3, 4, 5]
x = [0] * n # 0-1数组表示选或不选
def conflict(k):
if sum(x[:k + 1]) > r:
return True
if sum(x[:k + 1]) + (n - k - 1) < r:
return True
return False
def comb(k):
if k >= n:
if sum(x) == r:
result = [a[i] for i in range(n) if x[i] == 1]
print(result)
else:
for i in [1, 0]:
x[k] = i
if not conflict(k):
comb(k + 1)
comb(0)
核心要点:
- 子集树模板
- 剪枝:已选数量超过r或剩余不够r
2.9 选排问题(可重复排列)
问题描述:从n个元素中选m个排列,每个元素最多可重复r次。
代码实现:
n, m, r = 4, 3, 2
a = ['a', 'b', 'c', 'd']
x = [0] * m
def conflict(k):
if x[:k + 1].count(x[k]) > r:
return True
return False
def perm(k):
if k == m:
print(x)
else:
for i in a:
x[k] = i
if not conflict(k):
perm(k + 1)
perm(0)
核心要点:
- 允许重复但有次数限制
- 剪枝:某元素出现次数超限
2.10 最长公共子序列
问题描述:求两个字符串的最长公共子序列。
代码实现:
a, b = 'belong', 'cnblogs'
x = [] # 存储b中字符的索引
best_x = []
best_len = 0
def conflict(k):
if x[-1] < len(b) and a[k] != b[x[-1]]:
return True
if a[k] == b[x[-1]] and (len(x) >= 2 and x[-1] <= x[-2]):
return True
if len(x) + (len(a) - k) < best_len:
return True
return False
def LCS(k):
global best_len, best_x
if k == len(a):
if len(x) > best_len:
best_len = len(x)
best_x = x[:]
else:
for i in range(len(b) + 1):
if i == len(b):
LCS(k + 1)
else:
x.append(i)
if not conflict(k):
LCS(k + 1)
x.pop()
LCS(0)
result = ''.join([b[i] for i in best_x])
print(result)
核心要点:
- 用索引表示匹配位置
- 剪枝:剩余长度不足最优解
2.11 硬币找零问题
问题描述:给定不同面额的硬币及数量,凑成指定金额的最少硬币数。
代码实现:
n = 4
a = [10, 5, 2, 1] # 面额
b = [3, 5, 7, 12] # 数量
m = 53 # 目标金额
x = [0] * n
best_x = []
best_num = 0
def conflict(k):
if sum([a[i] * x[i] for i in range(k + 1)]) > m:
return True
if sum([a[i] * x[i] for i in range(k + 1)]) + \
sum([a[i] * b[i] for i in range(k + 1, n)]) < m:
return True
num = sum(x[:k + 1])
if 0 < best_num < num:
return True
return False
def subsets(k):
global best_num, best_x
if k == n:
if sum([a[i] * x[i] for i in range(n)]) == m:
num = sum(x)
if best_num == 0 or num < best_num:
best_num = num
best_x = x[:]
else:
for i in range(b[k] + 1):
x[k] = i
if not conflict(k):
subsets(k + 1)
subsets(0)
print(best_x)
核心要点:
- 每个面额有数量限制
- 剪枝:金额超限、金额不足、硬币数超优
2.12 作业调度问题
问题描述:n个作业在两台机器上加工,每个作业先机器1后机器2,求最优调度使总时间最小。
代码实现:
n = 3
t = [[2, 1], [3, 1], [2, 3]] # 每个作业在两台机器上的时间
x = [0] * n
best_x = []
best_t = 0
def conflict(k):
if x[:k + 1].count(x[k]) > 1:
return True
j2_t = []
s = 0
for i in range(k + 1):
s += t[x[i]][0]
j2_t.append(s + t[x[i]][1])
total_t = sum(j2_t)
if total_t > best_t > 0:
return True
return False
def dispatch(k):
global best_t, best_x
if k == n:
j2_t = []
s = 0
for i in range(n):
s += t[x[i]][0]
j2_t.append(s + t[x[i]][1])
total_t = sum(j2_t)
if best_t == 0 or total_t < best_t:
best_t = total_t
best_x = x[:]
else:
for i in range(n):
x[k] = i
if not conflict(k):
dispatch(k + 1)
dispatch(0)
print(best_x, best_t)
核心要点:
- 排列树模板(作业顺序)
- 计算机器2的完成时间
2.13 机器人运动范围
问题描述:机器人从(0,0)出发,可以上下左右移动,但不能进入行坐标和列坐标的数位之和大于threshold的格子,求能到达的格子数。
代码实现:
class Solution:
def movingCount(self, threshold, rows, cols):
board = [[0 for i in range(cols)] for j in range(rows)]
global acc
acc = 0
def block(r, c):
s = sum(map(int, str(r) + str(c)))
return s > threshold
def traverse(r, c):
global acc
if not (0 <= r < rows and 0 <= c < cols):
return
if board[r][c] != 0:
return
if block(r, c):
board[r][c] = -1
return
board[r][c] = 1
acc += 1
traverse(r + 1, c)
traverse(r - 1, c)
traverse(r, c + 1)
traverse(r, c - 1)
traverse(0, 0)
return acc
2.14 矩阵中的路径
问题描述:判断矩阵中是否存在一条包含某字符串所有字符的路径。
代码实现:
class MatrixPath:
def hasPath(self, matrix, rows, cols, path):
for i in range(rows):
for j in range(cols):
if matrix[i * cols + j] == path[0]:
if self.findPath(list(matrix), rows, cols, path[1:], i, j):
return True
return False
def findPath(self, matrix, rows, cols, path, x, y):
if not path:
return True
matrix[x * cols + y] = '*'
# 四个方向探索
if y + 1 < cols and matrix[x * cols + y + 1] == path[0]:
return self.findPath(matrix, rows, cols, path[1:], x, y + 1)
elif y - 1 >= 0 and matrix[x * cols + y - 1] == path[0]:
return self.findPath(matrix, rows, cols, path[1:], x, y - 1)
elif x + 1 < rows and matrix[(x + 1) * cols + y] == path[0]:
return self.findPath(matrix, rows, cols, path[1:], x + 1, y)
elif x - 1 >= 0 and matrix[(x - 1) * cols + y] == path[0]:
return self.findPath(matrix, rows, cols, path[1:], x - 1, y)
else:
return False
2.15 传教士与野人问题
问题描述:n个传教士和n个野人渡河,船最多载m人,任何时刻两岸传教士人数不能少于野人(除非无人),求安全渡河方案。
代码实现:
n, m = 3, 2
x = [] # 船的状态
is_found = False
def get_states(k):
if k % 2 == 0: # 从左到右
s1, s2 = n - sum(s[0] for s in x), n - sum(s[1] for s in x)
else: # 从右到左
s1, s2 = sum(s[0] for s in x), sum(s[1] for s in x)
for i in range(s1 + 1):
for j in range(s2 + 1):
if 0 < i + j <= m and (i * j == 0 or i >= j):
yield [(-i, -j), (i, j)][k % 2 == 0]
def conflict(k):
if k > 0 and x[-1][0] == -x[-2][0] and x[-1][1] == -x[-2][1]:
return True
if 0 < n - sum(s[0] for s in x) < n - sum(s[1] for s in x):
return True
if 0 < sum(s[0] for s in x) < sum(s[1] for s in x):
return True
return False
def backtrack(k):
global is_found
if is_found:
return
if n - sum(s[0] for s in x) == 0 and n - sum(s[1] for s in x) == 0:
print(x)
is_found = True
else:
for state in get_states(k):
x.append(state)
if not conflict(k):
backtrack(k + 1)
x.pop()
backtrack(0)
三、回溯算法模板总结
3.1 子集树模板
适用于0-1背包、组合等问题:
def backtrack(k):
if k >= n: # 到达叶子节点
process_solution()
else:
for i in [0, 1]: # 两种状态
x.append(i)
if not conflict(k):
backtrack(k + 1)
x.pop()
3.2 排列树模板
适用于全排列、TSP、作业调度等问题:
def backtrack(k):
if k >= n:
process_solution()
else:
for i in range(k, n):
x[k], x[i] = x[i], x[k]
if not conflict(k):
backtrack(k + 1)
x[i], x[k] = x[k], x[i]
3.3 剪枝策略
提高回溯效率的关键:
- 约束剪枝:不满足约束条件时剪枝
- 限界剪枝:当前解不可能优于已知最优解时剪枝
- 对称剪枝:利用问题的对称性去重
- 提前判断:剩余资源不足以完成解时剪枝
四、回溯算法的应用场景
回溯算法适用于以下类型的问题:
- 组合问题:从集合中选取元素的组合
- 排列问题:元素的全排列
- 子集问题:集合的所有子集
- 棋盘问题:N皇后、骑士巡游
- 图论问题:图的着色、哈密顿路径
- 路径搜索:迷宫、机器人运动
- 调度问题:作业调度、课程安排
- 资源分配:背包、硬币找零
五、回溯算法的优化
- 搜索顺序优化:分支少的节点优先搜索
- 约束传播:提前排除不可能的解
- 对称性剪枝:避免重复解
- 记忆化搜索:存储中间结果避免重复计算
- 启发式搜索:优先搜索可能产生解的分支
六、总结
回溯算法是一种通用的问题求解方法,通过深度优先搜索系统地遍历解空间,利用剪枝函数避免无效搜索。它的优点是能够找到所有解或最优解,缺点是当问题规模较大时可能面临组合爆炸。在实际应用中,需要根据问题特点设计合适的解空间表示、剪枝策略和搜索顺序,才能高效地解决问题。
通过以上15个经典案例,我们可以看到回溯算法的强大表达能力和广泛应用领域。掌握回溯算法的模板和技巧,对于解决算法竞赛和实际工作中的复杂问题都有很大帮助。