为什么 'groupby(x, np.isnan)' 的行为与 'groupby(x) if key is nan' 不同?

Why does 'groupby(x, np.isnan)' behave differently to 'groupby(x) if key is nan'?

,我发现了一些我也不明白的东西。我发布这个问题主要是作为 MSeifert 的扩展,因为我们的观察结果似乎有共同的原因。

早些时候, 涉及在包含 nan 个值的序列上使用 itertools.groupby

return max((sum(1 for _ in group) for key, group in groupby(sequence) if key is nan), default=0)

但是,我在上面链接的 MSeifert 的问题上看到 this answer,它显示了我可能制定此算法的另一种方法:

return max((sum(1 for _ in group) for key, group in groupby(sequence, np.isnan)), default=0)

实验

我已经用列表和 numpy 数组测试了这两种变体。代码和结果如下:

from itertools import groupby

from numpy import nan
import numpy as np


def longest_nan_run(sequence):
    return max((sum(1 for _ in group) for key, group in groupby(sequence) if key is nan), default=0)


def longest_nan_run_2(sequence):
    return max((sum(1 for _ in group) for key, group in groupby(sequence, np.isnan)), default=0)


if __name__ == '__main__':
    nan_list = [nan, nan, nan, 0.16, 1, 0.16, 0.9999, 0.0001, 0.16, 0.101, nan, 0.16]
    nan_array = np.array(nan_list)

    print(longest_nan_run(nan_list))  # 3 - correct
    print(longest_nan_run_2(nan_list))  # 7 - incorrect
    print(longest_nan_run(nan_array))  # 0 - incorrect
    print(longest_nan_run_2(nan_array))  # 7 - incorrect

分析

谁能解释这些结果?同样,由于这个问题与 MSeifert 的相关,对他的结果的解释可能也能解释我的结果(反之亦然)。


进一步调查

为了更好地了解正在发生的事情,我尝试打印出由 groupby:

生成的组
def longest_nan_run(sequence):
    print(list(list(group) for key, group in groupby(sequence) if key is nan))
    return max((sum(1 for _ in group) for key, group in groupby(sequence) if key is nan), default=0)


def longest_nan_run_2(sequence):
    print(list(list(group) for _, group in groupby(sequence, np.isnan)))
    return max((sum(1 for _ in group) for key, group in groupby(sequence, np.isnan)), default=0)

一个根本区别(回想起来是有道理的)是​​原始函数(使用if key is nan)会过滤掉除之外的所有内容 nan 个值,因此所有生成的组将仅包含 nan 个值,如下所示:

[[nan, nan, nan], [nan]]

另一方面,修改后的函数会将所有非nan值分组到它们自己的组中,如下所示:

[[nan, nan, nan], [0.16, 1.0, 0.16, 0.99990000000000001, 0.0001, 0.16, 0.10100000000000001], [nan], [0.16]]

这解释了为什么在这两种情况下修改函数 returns 7 - 它正在考虑值作为“nan”或 "not nan" 并返回最长的连续系列要么。

这也意味着我对 groupby(sequence, keyfunc) 工作原理的假设是错误的,并且修改后的函数不是原始函数的可行替代方案。

不过,我仍然不确定 运行 原始函数在列表和数组上的结果差异。

numpy 数组中的项目访问行为与列表中的行为不同:

nan_list[0] == nan_list[1]
# False
nan_list[0] is nan_list[1]
# True

nan_array[0] == nan_array[1]
# False
nan_array[0] is nan_array[1]
# False

x = np.array([1])
x[0] == x[0]
# True
x[0] is x[0]
# False

虽然列表包含对同一对象的引用,但 numpy 数组 'contain' 仅存储一个内存区域,并在每次访问元素时动态创建新的 Python 对象。 (感谢 user2357112,指出措辞不准确的地方。)

有道理,对吧?列表返回相同的对象,数组返回不同的对象 - 显然 groupby 内部使用 is 进行比较......但是等等,这并不容易!为什么 groupby(np.array([1, 1, 1, 2, 3])) 可以正常工作?

答案隐藏在 itertools C source, line 90 shows that the function PyObject_RichCompareBool 用于比较两个键。

rcmp = PyObject_RichCompareBool(gbo->tgtkey, gbo->currkey, Py_EQ);

虽然这基本上等同于在 Python 中使用 ==,但文档指出了一个特殊性:

Note If o1 and o2 are the same object, PyObject_RichCompareBool() will always return 1 for Py_EQ and 0 for Py_NE.

这意味着实际上执行了这个比较(等效代码):

if o1 is o2:
    return True
else:
    return o1 == o2

因此对于列表,我们有相同的 nan 个对象,它们被标识为相等。相反,数组为我们提供了具有值 nan 的不同对象,这些对象与 == 进行比较 - 但 nan == nan 始终评估为 False.

好吧,我想我已经为自己描绘了一幅足够清晰的画面。

这里有两个因素在起作用:

  • 我自己对 keyfunc 论点对 groupby 的影响的误解。
  • 关于 Python 如何在数组和列表中表示 nan 值的(更有趣的)故事,在 .
  • 中有最好的解释

解释 keyfunc 因素

来自documentation on groupby:

It generates a break or new group every time the value of the key function changes

来自documentation on np.isnan:

For scalar input, the result is a new boolean with value True if the input is NaN; otherwise the value is False.

基于这两件事,我们推断当我们将keyfunc设置为np.isnan时,传递给groupyby的序列中的每个元素将被映射到True ] 或 False,取决于它是否是 nan。这意味着键函数只会在 nan 元素和非 nan 元素之间的边界处发生变化,因此 groupby 只会将序列拆分为 [=13= 的连续块] 和非 nan 元素。

相比之下,原始函数(使用groupby(sequence) if key is nan)将使用身份函数用于keyfunc(它的默认值)。这自然会导致下面解释的 nan 身份的细微差别(以及上面的链接答案),但这里的重点是 if key is nan 将过滤掉所有键入非 nan个元素。

解释 nan 身份中的细微差别

正如我在上面链接的答案中更好地解释的那样,Python 的 built-in 列表中出现的所有 nan 实例似乎都是 一个并且同一个实例。换句话说,列表中所有出现的 nan 都指向内存中的同一个位置。与此相反,nan 元素是在使用 numpy 数组时动态生成的,因此都是单独的对象。

使用以下代码进行演示:

def longest_nan_run(sequence):
    print(id(nan))
    print([id(x) for x in sequence])
    return max((sum(1 for _ in group) for key, group in groupby(sequence) if key is nan), default=0)

当我 运行 使用原始问题中定义的 list 时,我得到了这个输出(相同的元素被突出显示):

4436731128
[4436731128, 44436731128, 44436731128, 4436730432, 4435753536, 4436730432, 4436730192, 4436730048, 4436730432, 4436730552, 44436731128, 4436730432]

另一方面,数组元素在内存中的处理方式似乎非常不同:

4343850232
[4357386696, 4357386720, 4357386696, 4357386720, 4357386696, 4357386720, 4357386696, 4357386720, 4357386696,, 4357386720, 4357386696, 4357386720]

该函数似乎在内存中的两个不同位置之间交替来存储这些值。请注意,none 个元素与筛选条件中使用的 nan 相同。

案例研究

我们现在可以将我们收集到的所有这些信息应用到实验中使用的四个独立案例来解释我们的观察结果。

带列表的原始函数

在这种情况下,我们使用默认的 identity 函数作为 keyfunc,我们已经看到列表中每次出现的 nan 实际上都是同一个实例。过滤器条件 if key is nan 中使用的 nan 与列表中的 nan 元素 相同,导致 groupby 中断在适当的地方列出,只保留包含 nan 的组。这就是这个变体起作用的原因,我们得到 3.

的正确结果

带数组的原始函数

同样,我们使用默认的 identity 函数作为 keyfunc,但这次所有 nan 事件 - 包括过滤器条件中的那个 - 指向不同的对象。这意味着条件过滤器 if key is nan 将对 all 组失败。由于我们无法找到空集合的最大值,因此我们求助于默认值 0.

使用列表和数组修改函数

在这两种情况下,我们使用 np.isnan 作为 keyfunc。这将导致 groupby 将序列拆分为 nan 和非 nan 元素的连续序列。

我们实验用的list/array,最长的nan个元素的序列是[nan, nan, nan],有3个元素,最长的非[=13]个序列=] 元素是 [0.16, 1, 0.16, 0.9999, 0.0001, 0.16, 0.101],它有 7 个元素。

max 将 select 这两个序列中较长的一个,在这两种情况下 return 7