回溯算法思想与经典案例解析

一、回溯算法概述

回溯算法(Backtracking)是一种通过暴力穷举方式解决问题的系统搜索方法,它是深度优先搜索的一种具体应用。回溯算法的核心思想是:走不通就退回重走,在包含问题所有解的路径树中,按照深度优先策略,从根节点出发搜索解空间树。

1.1 回溯算法的基本特征

回溯算法具有以下特征:

  • 系统性:按深度优先策略系统搜索解空间
  • 跳跃性:判断节点不包含解时,跳过该子树回溯
  • 递归实现:通常用最简单的递归方法实现

1.2 回溯算法的步骤

使用回溯算法的一般步骤如下:

  1. 定义解空间:子集树或排列树
  2. 组织解空间:用适合搜索的方法组织
  3. 深度优先搜索:按深度优先策略搜索
  4. 剪枝优化:利用剪枝函数避免无效搜索

1.3 回溯算法的三要素

  • 解空间:要解决问题的范围
  • 约束条件:包括显性和隐性的限制
  • 状态树:搜索过程展开的依据

二、经典回溯算法案例

下面通过15个经典案例,深入理解回溯算法的应用。

2.1 八皇后问题

问题描述:在8×8的国际象棋棋盘上摆放8个皇后,使其不能互相攻击(任意两个皇后不在同一行、同一列、同一对角线上)。

代码实现

n = 8
x = []  # 一个解(n元数组)
X = []  # 一组解

# 冲突检测
def conflict(k):
    global x
    for i in range(k):
        if x[i] == x[k] or abs(x[i] - x[k]) == abs(i - k):
            return True
    return False

# 回溯求解
def queens(k):
    global n, x, X
    if k >= n:
        X.append(x[:])
    else:
        for i in range(n):
            x.append(i)
            if not conflict(k):
                queens(k + 1)
            x.pop()

# 可视化
def show(x):
    for i in range(n):
        print('. ' * (x[i]) + 'X ' + '. ' * (n - x[i] - 1))

queens(0)
show(X[0])

核心要点

  • 用x[i]表示第i行皇后所在的列
  • 冲突条件:同列(x[i]==x[k])或对角线(abs(x[i]-x[k])==abs(i-k))
  • 递归逐行放置皇后

2.2 0-1背包问题

问题描述:给定n个物品,每个物品有重量w和价值v,在背包容量c的限制下,如何选择物品使总价值最大。

代码实现

n = 3  # 物品数量
c = 30  # 背包容量
w = [20, 15, 15]  # 物品重量
v = [45, 25, 25]  # 物品价值

bag = [0, 0, 0]  # 一个解(n元0-1数组)
maxv = 0  # 最大价值
bestbag = None  # 最佳解

# 冲突检测
def conflict(k):
    if sum([w[i] for i in range(k + 1) if bag[i] == 1]) > c:
        return True
    return False

# 回溯求解
def backpack(k):
    global bag, maxv, bestbag
    if k == n:
        cv = sum([v[i] for i in range(n) if bag[i] == 1])
        if cv > maxv:
            maxv = cv
            bestbag = bag[:]
    else:
        for i in [1, 0]:  # 选或不选
            bag[k] = i
            if not conflict(k):
                backpack(k + 1)

backpack(0)
print(bestbag, maxv)

核心要点

  • 子集树模板:每个物品有选或不选两种状态
  • 剪枝条件:当前重量超过背包容量

2.3 旅行商问题(TSP)

问题描述:给定n个城市及城市间的距离,找一条经过所有城市一次且回到起点的最短路径。

代码实现

n = 5  # 节点数
graph = [
    {1: 7, 2: 6, 3: 1, 4: 3},
    {0: 7, 2: 3, 3: 7, 4: 8},
    {0: 6, 1: 3, 3: 12, 4: 11},
    {0: 1, 1: 7, 2: 12, 4: 2},
    {0: 3, 1: 8, 2: 11, 3: 2}
]

x = [0] * (n + 1)  # 一个解(n+1元数组,最后回到起点)
min_cost = 0
best_x = None

# 冲突检测
def conflict(k):
    if k < n and x[k] in x[:k]:
        return True
    if k == n and x[k] != x[0]:
        return True
    cost = sum([graph[x[i]][x[i + 1]] for i in range(k)])
    if 0 < min_cost < cost:
        return True
    return False

# 回溯求解
def tsp(k):
    global min_cost, best_x
    if k > n:
        cost = sum([graph[x[i]][x[i + 1]] for i in range(n)])
        if min_cost == 0 or cost < min_cost:
            min_cost = cost
            best_x = x[:]
    else:
        for node in graph[x[k - 1]]:
            x[k] = node
            if not conflict(k):
                tsp(k + 1)

x[0] = 2  # 从节点c出发
tsp(1)
print(best_x, min_cost)

核心要点

  • 排列树模板:每个城市只访问一次
  • 剪枝条件:当前路径长度已超过已知最优解

2.4 图的m着色问题

问题描述:给定无向连通图和m种颜色,给每个顶点着色,使相邻顶点颜色不同。

代码实现

n = 5
graph = [
    {1, 2, 3},     # 节点0的邻接点
    {0, 2, 3, 4},  # 节点1的邻接点
    {0, 1, 3},     # 节点2的邻接点
    {0, 1, 2, 4},  # 节点3的邻接点
    {1, 3}         # 节点4的邻接点
]
m = 4  # 4种颜色

x = [0] * n  # 一个解(每个节点的颜色编号)

# 冲突检测
def conflict(k):
    nodes = [node for node in range(k) if node in graph[k]]
    if x[k] in [x[node] for node in nodes]:
        return True
    return False

# 回溯求解
def dfs(k):
    if k == n:
        print(x)
    else:
        for color in range(m):
            x[k] = color
            if not conflict(k):
                dfs(k + 1)

dfs(0)

核心要点

  • 每个节点有m种颜色可选
  • 冲突条件:相邻节点颜色相同

2.5 迷宫问题

问题描述:在给定的迷宫中,找一条从入口到出口的路径。

代码实现

maze = [[1,1,1,1,1,1,1,1,1,1],
        [0,0,1,0,1,1,1,1,0,1],
        [1,1,0,1,0,1,1,0,1,1],
        [1,0,1,1,1,0,0,1,1,1],
        [1,1,1,0,0,1,1,0,1,1],
        [1,1,0,1,1,1,1,1,0,1],
        [1,0,1,0,0,1,1,1,1,0],
        [1,1,1,1,1,0,1,1,1,1]]

entry = (1, 0)  # 入口
path = [entry]
paths = []

# 8个移动方向
directions = [(-1,0), (-1,1), (0,1), (1,1), (1,0), (1,-1), (0,-1), (-1,-1)]

# 冲突检测
def conflict(nx, ny):
    if 0 <= nx < 8 and 0 <= ny < 10 and maze[nx][ny] == 0:
        return False
    return True

# 回溯求解
def walk(x, y):
    if (x, y) != entry and (x % 7 == 0 or y % 9 == 0):
        paths.append(path[:])
    else:
        for d in directions:
            nx, ny = x + d[0], y + d[1]
            path.append((nx, ny))
            if not conflict(nx, ny):
                maze[nx][ny] = 2
                walk(nx, ny)
                maze[nx][ny] = 0
            path.pop()

walk(1, 0)
print(paths[-1])

核心要点

  • 每个位置有8个方向可探索
  • 用标记法避免重复访问

2.6 骑士巡游问题

问题描述:在国际象棋棋盘上,让骑士走遍所有格子且不重复。

代码实现

SIZE = 5
total = 0

def print_board(board):
    for row in board:
        for col in row:
            print(str(col).center(4), end='')
        print()

def patrol(board, row, col, step=1):
    if row >= 0 and row < SIZE and col >= 0 and col < SIZE and board[row][col] == 0:
        board[row][col] = step
        if step == SIZE * SIZE:
            global total
            total += 1
            print(f'第{total}种走法: ')
            print_board(board)
        # 8个可能的方向
        patrol(board, row - 2, col - 1, step + 1)
        patrol(board, row - 1, col - 2, step + 1)
        patrol(board, row + 1, col - 2, step + 1)
        patrol(board, row + 2, col - 1, step + 1)
        patrol(board, row + 2, col + 1, step + 1)
        patrol(board, row + 1, col + 2, step + 1)
        patrol(board, row - 1, col + 2, step + 1)
        patrol(board, row - 2, col + 1, step + 1)
        board[row][col] = 0

board = [[0] * SIZE for _ in range(SIZE)]
patrol(board, SIZE - 1, SIZE - 1)

核心要点

  • 骑士有8种走法
  • 用步数标记访问顺序

2.7 全排列问题

问题描述:求给定集合的所有排列。

代码实现

n = 4
a = ['a', 'b', 'c', 'd']
x = [0] * n

def perm(k):
    if k >= n:
        print(x)
    else:
        for i in set(a) - set(x[:k]):
            x[k] = i
            perm(k + 1)

perm(0)

另一种实现(交换法)

x = [1, 2, 3, 4]

def backkrak(k):
    if k >= n:
        print(x)
    else:
        for i in range(k, n):
            x[k], x[i] = x[i], x[k]
            backkrak(k + 1)
            x[i], x[k] = x[k], x[i]

backkrak(0)

核心要点

  • 排列树模板
  • 交换法更高效

2.8 组合问题

问题描述:从n个元素中选r个的所有组合。

代码实现

n, r = 5, 3
a = [1, 2, 3, 4, 5]
x = [0] * n  # 0-1数组表示选或不选

def conflict(k):
    if sum(x[:k + 1]) > r:
        return True
    if sum(x[:k + 1]) + (n - k - 1) < r:
        return True
    return False

def comb(k):
    if k >= n:
        if sum(x) == r:
            result = [a[i] for i in range(n) if x[i] == 1]
            print(result)
    else:
        for i in [1, 0]:
            x[k] = i
            if not conflict(k):
                comb(k + 1)

comb(0)

核心要点

  • 子集树模板
  • 剪枝:已选数量超过r或剩余不够r

2.9 选排问题(可重复排列)

问题描述:从n个元素中选m个排列,每个元素最多可重复r次。

代码实现

n, m, r = 4, 3, 2
a = ['a', 'b', 'c', 'd']
x = [0] * m

def conflict(k):
    if x[:k + 1].count(x[k]) > r:
        return True
    return False

def perm(k):
    if k == m:
        print(x)
    else:
        for i in a:
            x[k] = i
            if not conflict(k):
                perm(k + 1)

perm(0)

核心要点

  • 允许重复但有次数限制
  • 剪枝:某元素出现次数超限

2.10 最长公共子序列

问题描述:求两个字符串的最长公共子序列。

代码实现

a, b = 'belong', 'cnblogs'
x = []  # 存储b中字符的索引
best_x = []
best_len = 0

def conflict(k):
    if x[-1] < len(b) and a[k] != b[x[-1]]:
        return True
    if a[k] == b[x[-1]] and (len(x) >= 2 and x[-1] <= x[-2]):
        return True
    if len(x) + (len(a) - k) < best_len:
        return True
    return False

def LCS(k):
    global best_len, best_x
    if k == len(a):
        if len(x) > best_len:
            best_len = len(x)
            best_x = x[:]
    else:
        for i in range(len(b) + 1):
            if i == len(b):
                LCS(k + 1)
            else:
                x.append(i)
                if not conflict(k):
                    LCS(k + 1)
                x.pop()

LCS(0)
result = ''.join([b[i] for i in best_x])
print(result)

核心要点

  • 用索引表示匹配位置
  • 剪枝:剩余长度不足最优解

2.11 硬币找零问题

问题描述:给定不同面额的硬币及数量,凑成指定金额的最少硬币数。

代码实现

n = 4
a = [10, 5, 2, 1]  # 面额
b = [3, 5, 7, 12]  # 数量
m = 53  # 目标金额

x = [0] * n
best_x = []
best_num = 0

def conflict(k):
    if sum([a[i] * x[i] for i in range(k + 1)]) > m:
        return True
    if sum([a[i] * x[i] for i in range(k + 1)]) + \
       sum([a[i] * b[i] for i in range(k + 1, n)]) < m:
        return True
    num = sum(x[:k + 1])
    if 0 < best_num < num:
        return True
    return False

def subsets(k):
    global best_num, best_x
    if k == n:
        if sum([a[i] * x[i] for i in range(n)]) == m:
            num = sum(x)
            if best_num == 0 or num < best_num:
                best_num = num
                best_x = x[:]
    else:
        for i in range(b[k] + 1):
            x[k] = i
            if not conflict(k):
                subsets(k + 1)

subsets(0)
print(best_x)

核心要点

  • 每个面额有数量限制
  • 剪枝:金额超限、金额不足、硬币数超优

2.12 作业调度问题

问题描述:n个作业在两台机器上加工,每个作业先机器1后机器2,求最优调度使总时间最小。

代码实现

n = 3
t = [[2, 1], [3, 1], [2, 3]]  # 每个作业在两台机器上的时间
x = [0] * n
best_x = []
best_t = 0

def conflict(k):
    if x[:k + 1].count(x[k]) > 1:
        return True
    
    j2_t = []
    s = 0
    for i in range(k + 1):
        s += t[x[i]][0]
        j2_t.append(s + t[x[i]][1])
    total_t = sum(j2_t)
    if total_t > best_t > 0:
        return True
    return False

def dispatch(k):
    global best_t, best_x
    if k == n:
        j2_t = []
        s = 0
        for i in range(n):
            s += t[x[i]][0]
            j2_t.append(s + t[x[i]][1])
        total_t = sum(j2_t)
        if best_t == 0 or total_t < best_t:
            best_t = total_t
            best_x = x[:]
    else:
        for i in range(n):
            x[k] = i
            if not conflict(k):
                dispatch(k + 1)

dispatch(0)
print(best_x, best_t)

核心要点

  • 排列树模板(作业顺序)
  • 计算机器2的完成时间

2.13 机器人运动范围

问题描述:机器人从(0,0)出发,可以上下左右移动,但不能进入行坐标和列坐标的数位之和大于threshold的格子,求能到达的格子数。

代码实现

class Solution:
    def movingCount(self, threshold, rows, cols):
        board = [[0 for i in range(cols)] for j in range(rows)]
        global acc
        acc = 0
        
        def block(r, c):
            s = sum(map(int, str(r) + str(c)))
            return s > threshold
        
        def traverse(r, c):
            global acc
            if not (0 <= r < rows and 0 <= c < cols):
                return
            if board[r][c] != 0:
                return
            if block(r, c):
                board[r][c] = -1
                return
            
            board[r][c] = 1
            acc += 1
            traverse(r + 1, c)
            traverse(r - 1, c)
            traverse(r, c + 1)
            traverse(r, c - 1)
        
        traverse(0, 0)
        return acc

2.14 矩阵中的路径

问题描述:判断矩阵中是否存在一条包含某字符串所有字符的路径。

代码实现

class MatrixPath:
    def hasPath(self, matrix, rows, cols, path):
        for i in range(rows):
            for j in range(cols):
                if matrix[i * cols + j] == path[0]:
                    if self.findPath(list(matrix), rows, cols, path[1:], i, j):
                        return True
        return False
    
    def findPath(self, matrix, rows, cols, path, x, y):
        if not path:
            return True
        matrix[x * cols + y] = '*'
        
        # 四个方向探索
        if y + 1 < cols and matrix[x * cols + y + 1] == path[0]:
            return self.findPath(matrix, rows, cols, path[1:], x, y + 1)
        elif y - 1 >= 0 and matrix[x * cols + y - 1] == path[0]:
            return self.findPath(matrix, rows, cols, path[1:], x, y - 1)
        elif x + 1 < rows and matrix[(x + 1) * cols + y] == path[0]:
            return self.findPath(matrix, rows, cols, path[1:], x + 1, y)
        elif x - 1 >= 0 and matrix[(x - 1) * cols + y] == path[0]:
            return self.findPath(matrix, rows, cols, path[1:], x - 1, y)
        else:
            return False

2.15 传教士与野人问题

问题描述:n个传教士和n个野人渡河,船最多载m人,任何时刻两岸传教士人数不能少于野人(除非无人),求安全渡河方案。

代码实现

n, m = 3, 2
x = []  # 船的状态
is_found = False

def get_states(k):
    if k % 2 == 0:  # 从左到右
        s1, s2 = n - sum(s[0] for s in x), n - sum(s[1] for s in x)
    else:  # 从右到左
        s1, s2 = sum(s[0] for s in x), sum(s[1] for s in x)
    
    for i in range(s1 + 1):
        for j in range(s2 + 1):
            if 0 < i + j <= m and (i * j == 0 or i >= j):
                yield [(-i, -j), (i, j)][k % 2 == 0]

def conflict(k):
    if k > 0 and x[-1][0] == -x[-2][0] and x[-1][1] == -x[-2][1]:
        return True
    if 0 < n - sum(s[0] for s in x) < n - sum(s[1] for s in x):
        return True
    if 0 < sum(s[0] for s in x) < sum(s[1] for s in x):
        return True
    return False

def backtrack(k):
    global is_found
    if is_found:
        return
    if n - sum(s[0] for s in x) == 0 and n - sum(s[1] for s in x) == 0:
        print(x)
        is_found = True
    else:
        for state in get_states(k):
            x.append(state)
            if not conflict(k):
                backtrack(k + 1)
            x.pop()

backtrack(0)

三、回溯算法模板总结

3.1 子集树模板

适用于0-1背包、组合等问题:

def backtrack(k):
    if k >= n:  # 到达叶子节点
        process_solution()
    else:
        for i in [0, 1]:  # 两种状态
            x.append(i)
            if not conflict(k):
                backtrack(k + 1)
            x.pop()

3.2 排列树模板

适用于全排列、TSP、作业调度等问题:

def backtrack(k):
    if k >= n:
        process_solution()
    else:
        for i in range(k, n):
            x[k], x[i] = x[i], x[k]
            if not conflict(k):
                backtrack(k + 1)
            x[i], x[k] = x[k], x[i]

3.3 剪枝策略

提高回溯效率的关键:

  1. 约束剪枝:不满足约束条件时剪枝
  2. 限界剪枝:当前解不可能优于已知最优解时剪枝
  3. 对称剪枝:利用问题的对称性去重
  4. 提前判断:剩余资源不足以完成解时剪枝

四、回溯算法的应用场景

回溯算法适用于以下类型的问题:

  1. 组合问题:从集合中选取元素的组合
  2. 排列问题:元素的全排列
  3. 子集问题:集合的所有子集
  4. 棋盘问题:N皇后、骑士巡游
  5. 图论问题:图的着色、哈密顿路径
  6. 路径搜索:迷宫、机器人运动
  7. 调度问题:作业调度、课程安排
  8. 资源分配:背包、硬币找零

五、回溯算法的优化

  1. 搜索顺序优化:分支少的节点优先搜索
  2. 约束传播:提前排除不可能的解
  3. 对称性剪枝:避免重复解
  4. 记忆化搜索:存储中间结果避免重复计算
  5. 启发式搜索:优先搜索可能产生解的分支

六、总结

回溯算法是一种通用的问题求解方法,通过深度优先搜索系统地遍历解空间,利用剪枝函数避免无效搜索。它的优点是能够找到所有解或最优解,缺点是当问题规模较大时可能面临组合爆炸。在实际应用中,需要根据问题特点设计合适的解空间表示、剪枝策略和搜索顺序,才能高效地解决问题。

通过以上15个经典案例,我们可以看到回溯算法的强大表达能力和广泛应用领域。掌握回溯算法的模板和技巧,对于解决算法竞赛和实际工作中的复杂问题都有很大帮助。