加速排列

Speed up permutation

我有下一个任务。

给定整数 n (1 <= n <= 1000000) 和 k (1 <= k <= n)。

需要找到整数1, 2, 3, ..., n的任何排列p使得排列中每两个连续整数之间的绝对差为>= k,即对于排列p所有 i.

都需要 abs(p[i] - p[i + 1]) >= k

如果给定 nk 不存在这样的排列,则输出 Impossible.

原来的在线任务是波斯语,所以我不提供 link。

我已经实现了下一个代码。但是解决上述任务很慢。我怎样才能提高它的速度?

from itertools import permutations

n,k=input('').split(' ')

n=int(n);k=int(k)

def check(n,k):
    n=list(n)
    N=n[:]
    result=[]
    b=n.pop(0)
    while n:
        if abs(b-n[0]) >= k:
            result.append(True)
        b=n.pop(0)
    if len(result)+1 == len(N) and all(result):
        return True
    else:
        return False

def solver(n,k):
    return (i
        for i in (permutations(range(1,n+1)))
        if check(i,k)
    )
try:
    aaa=next(solver(n,k))
    for i in aaa:
        print(i,end=' ')
except :
    print("Impossible")

因为最简单的优化是通过使用非常流行的 NumPy 库来改进 check(...) 功能,只需在通过 python -m pip install numpy 命令行使用下一个代码之前安装 NumPy 一次.

Try it online!

def check(n, k):
    import numpy as np
    return np.abs(np.diff(n)).min() >= k

但是如果您的任务的解决方案应该 really-really 快得多,那么我实施了完整的下一个新解决方案,对于大型 n(100 万甚至更多)应该非常快。此外,我的下一个解决方案不使用任何额外的库,例如 numpy.

主要求解函数有 solve2(n, k) 它 returns 列出正确答案(正确性是 double-checked by assert),或者答案不存在它 returns None。此求解函数的用法示例是 test2() 函数,例如它为所有小的 nk 组合创建答案。如果求解函数 returns None 测试函数在打印前将其转换为字符串 Impossible

您必须按照您的案例所需的方式实现您自己的 test2() 变体,例如像您在原始源代码中所做的那样从用户那里获取输入 nk , 我刚刚做了 test2() 如何使用我的代码的例子。

Try it online!

def solve2(n, k):
    if k <= 1 or n <= 1:
        p = list(range(n))
    elif n < 2 * k - 1:
        p = None
    elif 2 * k - 1 <= n <= 2 * k + 1:
        p = [None] * n
        cnt, klo = 0, -1
        if n == 2 * k + 1:
            p[0], p[1], p[2], cnt, klo = 0, k, 2 * k, 3, 0
        for i in range(k - 1, klo, -1):
            p[cnt] = i
            cnt += 1
            if cnt >= n:
                break
            p[cnt] = i + k
            cnt += 1
    else:
        if 2 * (k + 1) <= n <= 3 * (k + 1):
            kst = k + 1
        else:
            kst = k
        p = [None] * n
        cnt = 0
        for i in range(kst):
            for j in range(i, n, kst):
                p[cnt] = j
                cnt += 1

    if p is not None:
        assert len(p) == n, (len(p), n)
        p = [(e + 1) for e in p]
        assert all(abs(f - s) >= k for f, s in zip(p[:-1], p[1:]))

    return p

def test2():
    for n in range(1, 16):
        for k in range(1, n + 1):
            answer = solve2(n, k)
            answer = answer if answer is not None else 'Impossible'
            print(n, k, answer, end = '  |  ', flush = True)

test2()