如何验证一个洗牌算法是否均匀?

How to verify that a shuffling algorithm is uniform?

我有一个简单的 Python 实现 Knuth's shuffling algorithm:

def knuth_shuffle(ar):
    num = len(ar)
    for i in range(num):
        index = random.randint(0, i)
        ar[i], ar[index] = ar[index], ar[i]
    return ar

如何测试(使用 scipy 或任何其他包)洗牌确实是均匀的?我找到了一些相关的帖子 (1, 2),但它们没有回答我的问题。如果能了解一般如何执行此类检查,那就太好了。

编辑:

如评论中的Paul Hankin,我原来的测试只检查每个元素落入每个位置的概率,而不是所有排列都是等概率的,这是一个更强的要求。下面的代码片段计算了每个排列的频率,这是我们应该关注的:

import math
import random

def knuth_shuffle(ar):
    num = len(ar)
    for i in range(num):
        index = random.randint(0, i)
        ar[i], ar[index] = ar[index], ar[i]
    return ar

# This function computes a unique index for a given permutation
# Adapted from https://www.jaapsch.net/puzzles/compindx.htm#perm
def permutation_index(permutation):
    n = len(permutation)
    t = 0
    for i in range(n):
      t = t * (n - i)
      for j in range(i + 1, n):
        if permutation[i] > permutation[j]:
            t += 1
    return t

N = 6  # Test list size
T = 1000  # Trials / number of permutations

random.seed(100)
n_perm = math.factorial(N)
trials = T * n_perm
ar = list(range(N))
freq = [0] * n_perm
for _ in range(trials):
    ar_shuffle = ar.copy()
    knuth_shuffle(ar_shuffle)
    freq[permutation_index(ar_shuffle)] += 1

如果随机播放是均匀的,则生成的 freq 向量的值应根据具有 T * N! 次试验和成功概率 1 / (N!) 的二项分布进行分布。这是前面示例的分布估计图(使用 Seaborn 完成),其中频率值应在 1000 左右:

我认为看起来 不错 但是同样,对于定量结果,您需要更深入的统计分析,例如 Pearson's chi-squared test, as suggested by David Eisenstat.


原始答案:

我将在这里提出一些基本的想法,但我没有最强大的统计背景,所以有人可能想要补充或纠正任何错误的地方。

您可以制作一个每个值落入每个位置的频率矩阵以进行多次试验:

def knuth_shuffle(ar):
    num = len(ar)
    for i in range(num):
        index = random.randint(0, i)
        ar[i], ar[index] = ar[index], ar[i]
    return ar

N = 100  # Test list size
T = 10000  # Number of trials
ar = list(range(N))
freq = [[0] * N for _ in range(N)]

for _ in range(T):
    ar_shuffle = ar.copy()
    kunth_shuffle(ar_shuffle)
    for i, j in enumerate(ar_shuffle):
        freq[i][j] += 1

一旦可以做到这一点,您可以采用多种方法。一个简单的想法是,如果洗牌是均匀的,freq / T 应该趋向于 1 / N,因为 T 趋向于无穷大。因此,您可以只使用 T 的 "very big" 值,然后看到这些值是 "close enough"。或者检查freq / T - 1 / N的标准差是"small enough".

这些"close enough"和"small enough"虽然不是很扎实的概念。扎根分析需要更多的统计工具。我认为您需要 test the hipothesis that each frequency value is sampled from a binomial distributionT 试验 1 / N 成功概率。正如我所说,没有完整解释的背景,这可能不是它的地方,但如果你真的需要一个彻底的分析,你可以阅读这个主题。

如果您按照给定的固定顺序随机打乱相同的物品,那么在打乱的物品中 一个 固定位置的每个物品的数量应该趋向于相同的值。

下面我将列表 0..9 打乱几次并打印输出:

from random import shuffle  # Uses Fischer-Yates

tries = 1_000_000
intcount = 10
first_position_counts = {n:0 for n in ints}
ints = range(intcount)
for _ in range(tries):
    lst = list(ints)   # [0, 1, ...9] In that order
    shuffle(lst)
    first_position_counts[lst[0]] += 1

print(f'{tries} shuffles of the ints 0..{intcount-1} should have each int \n',
      'appear in the first position {tries/intcount} times.')
for item in first_position_counts.items():
    print(' %i: %5i' % item)

运行一旦你可能得到类似的东西:

 0: 99947
 1: 100522
 2: 99828
 3: 100123
 4: 99582
 5: 99635
 6: 99991
 7: 100108
 8: 100172
 9: 100092

再一次:

 0: 100049
 1: 99918
 2: 100053
 3: 100285
 4: 100293
 5: 100034
 6: 99861
 7: 99584
 8: 100055
 9: 99868

现在,如果您有数千个项目要洗牌,那么它们应该以 n! 排列之一结束,但是 n! 会很快变大 ;如果它是 "comparable",肯定大于随机数生成器的可能范围,那么它就会崩溃。

您可以通过将所有可能的随机数序列注入 knuth_shuffle,然后验证您恰好得到每个排列一次来准确检查这一点。

此代码执行此操作:

import collections
import itertools
import random

def knuth_shuffle(ar, R=random.randint):
    num = len(ar)
    for i in range(num):
        index = R(0, i)
        ar[i], ar[index] = ar[index], ar[i]
    return ar

def fact(i):
    r = 1
    while i > 1:
        r *= i
        i -= 1
    return r

def all_random_seqs(N):
    for r in range(fact(N)):
        seq = []
        for i in range(N):
            seq.append(r % (i+1))
            r //= (i+1)
        it = iter(seq)
        yield lambda x, y: next(it)

for N in range(1, 6):
    print N
    results = collections.Counter()
    for R in all_random_seqs(N):
        a = list('ABCDEFG'[:N])
        knuth_shuffle(a, R)
        results[''.join(a)] += 1
    print 'checking...'
    if len(results) != fact(N):
        print 'N=%d. Not enough results. %s' % (N, results)
    if any(c > 1 for c in results.itervalues()):
        print 'N=%d. Not all permutations unique. %s' % (N, results)
    if any(sorted(c) != list('ABCDEFG'[:N]) for c in results.iterkeys()):
        print 'N=%d. Some permutations are illegal. %s' % (N, results)

此代码检查大小为 1、2、3、4、5 的输入列表的确切正确性。您可以在 N 之前更进一步!变得太大了。

您还需要使用 random.randint 对代码版本执行健全性检查(例如生成 500 次 'ABCD' 随机排列,并确保每个排列至少得到一次)。