在数据集中查找每个组中的前 3 个项目

Find top 3 items in each group in a dataset

有一个电影数据集。我想找出每年排名前三的流派(当年电影数量最多的流派)。 数据集摘录如下:

      year      genre  imdb_title_id

 19   1894    Romance              1
 29   1906  Biography              1
 31   1906      Crime              1
 33   1906      Drama              1
 58   1911      Drama              4
 73   1911        War              2
 52   1911  Adventure              1
 60   1911    Fantasy              1
 62   1911    History              1
 83   1912      Drama              5
 87   1912    History              2
 79   1912  Biography              1
 81   1912      Crime              1
 91   1912    Mystery              1
 98   1912        War              1
 108  1913      Drama             11
 106  1913      Crime              4
 110  1913    Fantasy              3
 102  1913  Adventure              2
 113  1913     Horror              2

如何在pandas中进行这种操作?我试过 nlargest 但没有得到正确的结果。 这种情况的预期输出应该是这样的:

19   1894    Romance              1
29   1906  Biography              1
31   1906      Crime              1
33   1906      Drama              1
58   1911      Drama              4
73   1911        War              2
52   1911  Adventure              1
83   1912      Drama              5
87   1912    History              2
79   1912  Biography              1
108  1913      Drama             11
106  1913      Crime              4
110  1913    Fantasy              3

我认为它有效:

df = df.sort_values(["imdb_title_id"], ascending=False)
df = df.groupby("year", as_index=False).agg({"genre": lambda x: list(x)[:3], "imdb_title_id": lambda x: list(x)[:3]})
result = df.explode("genre", ignore_index=True)
result["imdb_title_id"] = df.explode("imdb_title_id")["imdb_title_id"].values

但可以找到更好的方法。

nlargest() 应该 'just work' 但这里有一些示例代码来处理邪恶的索引问题。

top3_idx = df.groupby("year")["imdb_title_id"].nlargest(3).droplevel(0).index
top3_df = df.iloc[top3_idx]

基本上你按年份获得最大的 nlargest 然后使用索引值来过滤你的数据框。