快速排序是常见的排序算法,可分为 partition 和 recursion 两个部分,至于 partition 的具体实现,有一些不同的方式可选。

本文使用 Python 且不对数据的合法性进行判断,前提要求是对一个列表类型的数组 nums 进行原地排序。

# 先来一个最简单的,其实用的是 Timsort
nums.sort()

方法1:填坑

也就是教材上使用的方法

# 不同的快排写法,区别都在于 partition
def partition(nums, l, r):
    pivot = nums[l]
    while l < r:
        while l < r and nums[r] > pivot:
            r -= 1
        nums[l] = nums[r]
        while l < r and nums[l] <= pivot:
            l += 1
        nums[r] = nums[l]
    nums[l] = pivot
    return l

# 分治,递归
def q_sort(nums, l, r):
    if l >= r:
        return
    pivot_idx = partition(nums, l, r)
    q_sort(nums, l, pivot_idx-1)
    q_sort(nums, pivot_idx+1, r)


def quick_sort(nums):
    q_sort(nums, 0, len(nums)-1)

方法2:交换

在方法1的思路上,修改为交换两个位置的元素

def partition(nums, l, r):
    pivot = nums[l]
    start = l
    while l < r:
        while l < r and nums[r] > pivot:
            r -= 1
        while l < r and nums[l] <= pivot:
            l += 1
        if l < r:
            nums[l], nums[r] = nums[r], nums[l]
    nums[start], nums[l] = nums[l], nums[start]
    return l

方法3:区间维护

《算法导论》上面的写法:通过从左到右一次遍历,两个指针,比较并交换元素来维护区间 [l, i](i, j) 中的元素小于/大于 pivot

def partition(nums, l, r):
    pivot = nums[r]
    i = l - 1
    for j in range(l, r):
        if nums[j] <= pivot:
            i += 1
            nums[i], nums[j] = nums[j], nums[i]
    nums[i+1], nums[r] = nums[r], nums[i+1]
    return i+1

其实还有一些其它的写法,不过大同小异,掌握核心思路就行。