Julia:分区迭代器上的并行 for 循环

Julia: Parallel for loop over partitions iterator

所以我正在尝试遍历某些分区的列表,比如 1:n 用于 13 到 21 之间的一些 n。我理想中想要的代码 运行看起来像这样:

valid_num = @parallel (+) for p in partitions(1:n)
  int(is_valid(p))
end

println(valid_num)

这将使用 @parallel for 来映射减少我的问题。例如,将此与 Julia 文档中的示例进行比较:

nheads = @parallel (+) for i=1:200000000
  Int(rand(Bool))
end

但是,如果我尝试修改循环,我会收到以下错误:

ERROR: `getindex` has no method matching getindex(::SetPartitions{UnitRange{Int64}}, ::Int64)
 in anonymous at no file:1433
 in anonymous at multi.jl:1279
 in run_work_thunk at multi.jl:621
 in run_work_thunk at multi.jl:630
 in anonymous at task.jl:6

我认为这是因为我正在尝试迭代不是 1:n 形式的东西(编辑:我认为这是因为如果 p=partitions(1:n) 则不能调用 p[3]) .

我试过使用 pmap 来解决这个问题,但是因为分区的数量会变得非常大,非常快(1:13 有超过 250 万个分区,当我get to 1:21 things will be huge), 构建如此大的数组成为一个问题。我把它放在 运行 晚上了,但它仍然没有完成。

有没有人对我如何在 Julia 中有效地执行此操作有任何建议?我可以访问大约 30 核计算机,我的任务似乎很容易并行化,所以如果有人知道在 Julia 中执行此操作的好方法,我将不胜感激。

非常感谢!

一种方法是将问题分成不太大而无法实现的部分,然后并行处理每个部分中的项目,例如如下:

function my_take(iter,state,n)
    i = n
    arr = Array[]
    while !done(iter,state) && (i>0)
        a,state = next(iter,state)
        push!(arr,a)
        i = i-1
    end
    return arr, state
end

function get_part(npart,npar)
    valid_num = 0
    p = partitions(1:npart)
    s = start(p)
    while !done(p,s)
        arr,s = my_take(p,s,npar)
        valid_num += @parallel (+) for a in arr
            length(a)
        end
    end
    return valid_num
end

valid_num = @time get_part(10,30)

我打算使用 take() 方法从迭代器中实现最多 npar 项,但 take() 似乎已被弃用,所以我包含了我自己的实现已致电 my_take()。因此 getPart() 函数使用 my_take() 一次获取最多 npar 个分区并对它们进行计算。在这种情况下,计算只是将它们的长度相加,因为我没有 OP 的 is_valid() 函数的代码。 get_part() 然后 returns 结果。

因为 length() 计算不是很耗时,这段代码在并行处理器上 运行 实际上比在单处理器上慢:

$ julia -p 1 parpart.jl
elapsed time: 10.708567515 seconds (373025568 bytes allocated, 6.79% gc time)

$ julia -p 2 parpart.jl
elapsed time: 15.70633439 seconds (548394872 bytes allocated, 9.14% gc time)

或者,pmap() 可以用于每个问题而不是并行 for 循环。

关于内存问题,当我 运行 Julia 有 4 个工作进程时,从 partitions(1:10) 实现 30 个项目在我的电脑上占用了将近 1 GB 的内存,所以我希望实现一个小子集partitions(1:21) 将需要大量内存。在尝试这样的计算之前,可能需要估计需要多少内存以查看是否完全可行。

关于计算时间,注意:

julia> length(partitions(1:10))
115975

julia> length(partitions(1:21))
474869816156751

... 因此,即使在 30 个内核上进行高效并行处理也可能不足以在合理的时间内解决更大的问题。

下面的代码给出了 511,一组 10 个大小为 2 的分区数。

using Iterators
s = [1,2,3,4,5,6,7,8,9,10]
is_valid(p) = length(p)==2
valid_num = @parallel (+) for i = 1:30
  sum(map(is_valid, takenth(chain(1:29,drop(partitions(s), i-1)), 30)))
end

此解决方案结合了 takenth、drop 和 chain 迭代器,以获得与下面上一个答案下的 take_every 迭代器相同的效果。请注意,在此解决方案中,每个进程都必须计算每个分区。但是,因为每个进程对 drop 使用不同的参数,所以没有两个进程会在同一分区上调用 is_valid。

除非您想做大量数学运算来弄清楚如何实际跳过分区,否则无法避免至少在一个进程上按顺序计算分区。我认为 Simon 的回答是在一个进程上执行此操作并分配分区。我的要求每个工作进程自己计算分区,这意味着计算是重复的。但是,它是并行复制的,这(如果您实际上有 30 个处理器)不会花费您的时间。

这里是关于如何实际计算分区迭代器的资源:http://www.informatik.uni-ulm.de/ni/Lehre/WS03/DMM/Software/partitions.pdf

上一个答案(比必要的更复杂)

我在写我的时候注意到西蒙的回答。我们的解决方案似乎与我相似,除了我使用迭代器来避免将分区存储在内存中。我不确定对于什么大小的设置哪个实际上会更快,但我认为同时拥有这两个选项是件好事。假设计算 is_valid 比计算分区本身花费的时间要长得多,您可以这样做:

s = [1,2,3,4]
is_valid(p) = length(p)==2
valid_num = @parallel (+) for i = 1:30
  foldl((x,y)->(x + int(is_valid(y))), 0, take_every(partitions(s), i-1, 30))
end

这给了我 7,一组 4 的大小为 2 的分区数。take_every 函数 returns 一个迭代器,returns 每 30 个分区从第 i 个开始.这是相关代码:

import Base: start, done, next
immutable TakeEvery{Itr}
  itr::Itr
  start::Any
  value::Any
  flag::Bool
  skip::Int64
end
function take_every(itr, offset, skip)
  value, state = Nothing, start(itr)
  for i = 1:(offset+1)
    if done(itr, state)
      return TakeEvery(itr, state, value, false, skip)
    end
    value, state = next(itr, state)
  end
  if done(itr, state)
    TakeEvery(itr, state, value, true, skip)
  else
    TakeEvery(itr, state, value, false, skip)
  end
end
function start{Itr}(itr::TakeEvery{Itr})
  itr.value, itr.start, itr.flag
end
function next{Itr}(itr::TakeEvery{Itr}, state)
  value, state_, flag = state
  for i=1:itr.skip
    if done(itr.itr, state_)
      return state[1], (value, state_, false)
    end
    value, state_ = next(itr.itr, state_)
  end
  if done(itr.itr, state_)
    state[1], (value, state_, !flag)
  else
    state[1], (value, state_, false)
  end
end
function done{Itr}(itr::TakeEvery{Itr}, state)
  done(itr.itr, state[2]) && !state[3]
end