julia 数组 select 前 x 行按组

julia arrays select the first x rows by group

使用 julia,我想 select 每组数组的前 x 行。

在下面的示例中,我想要第二列等于 1.0 的前两行,然后是第二列等于 2.0 的前两行,等等。

XX = [repeat([1.0], 6) vcat(repeat([1.0], 3), repeat([2.0], 3))]
XX2 = [repeat([2.0], 6) vcat(repeat([3.0], 3), repeat([4.0], 3))]
beg = [XX;XX2]

> 12×2 Matrix{Float64}:
>  1.0  1.0
>  1.0  1.0
>  1.0  1.0
>  1.0  2.0
>  1.0  2.0
>  1.0  2.0
>  2.0  3.0
>  2.0  3.0
>  2.0  3.0
>  2.0  4.0
>  2.0  4.0
>  2.0  4.0

最终数组如下所示:

8×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0
 1.0  2.0
 1.0  2.0
 2.0  3.0
 2.0  3.0
 2.0  4.0
 2.0  4.0

我使用以下代码,但我不确定是否有更简单的方法(一个函数)已经以更有效的方式完成了?

x = []
for val in unique(beg[:,2])
    x = append!(x, findfirst(beg[:,2].==val))
end
idx = sort([x; x.+1])
final = beg[idx, :] 

假设您的数据:

  • 已排序(即组正在形成连续块)
  • 每组保证至少有两个元素

(您的代码假设两者都有)

然后您可以通过以下方式生成您想要的 idx 过滤器:

idx == [i for i in axes(beg, 1) if i < 3 || beg[i, 2] != beg[i-1, 2] || beg[i, 2] != beg[i-2, 2]]

如果您不能假设以上任一情况,请发表评论,我可以提供更通用的解决方案。

编辑

下面是一个没有使用任何外部包的例子:

julia> using Random

julia> XX = [repeat([1.0], 6) vcat(repeat([1.0], 3), repeat([2.0], 3))]
6×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0
 1.0  1.0
 1.0  2.0
 1.0  2.0
 1.0  2.0

julia> XX2 = [repeat([2.0], 7) vcat(repeat([3.0], 3), repeat([4.0], 3), 5.0)] # last group has length 1
7×2 Matrix{Float64}:
 2.0  3.0
 2.0  3.0
 2.0  3.0
 2.0  4.0
 2.0  4.0
 2.0  4.0
 2.0  5.0

julia> beg = [XX;XX2][randperm(13), :] # shuffle groups so they are not in order
13×2 Matrix{Float64}:
 2.0  3.0
 1.0  2.0
 2.0  4.0
 2.0  3.0
 2.0  4.0
 2.0  5.0
 2.0  3.0
 1.0  2.0
 1.0  2.0
 1.0  1.0
 1.0  1.0
 2.0  4.0
 1.0  1.0

julia> x = Dict{Float64, Vector{Int}}() # this will store indices per group
Dict{Float64, Vector{Int64}}()

julia> for (i, v) in enumerate(beg[:, 2]) # collect the indices
           push!(get!(x, v, Int[]), i)
       end

julia> x
Dict{Float64, Vector{Int64}} with 5 entries:
  5.0 => [6]
  4.0 => [3, 5, 12]
  2.0 => [2, 8, 9]
  3.0 => [1, 4, 7]
  1.0 => [10, 11, 13]

julia> idx = sort!(mapreduce(x -> first(x, 2), vcat, values(x))) # get first two indices per group in ascending order
9-element Vector{Int64}:
  1
  2
  3
  4
  5
  6
  8
 10
 11