寻找完美平方算法的优化

Optimization for finding perfect-square algorithm

我正在研究的 question 是:

Find which sum of squared factors are a perfect square given a specific range. So if the range was (1..10) you would get each number's factors (all factors for 1, all factors for 2, all factors for 3 ect..) Square those factors, then add them together. Finally check if that sum is a perfect square.

我卡在 refactoring/optimization 因为我的解决方案太慢了。

这是我想出的:

def list_squared(m, n)
  ans = []
  range = (m..n)

  range.each do |i|
    factors = (1..i).select { |j| i % j == 0 }
    squares = factors.map { |k| k ** 2 }
    sum = squares.inject { |sum,x| sum + x }
    if sum == Math.sqrt(sum).floor ** 2
      all = []
      all += [i, sum]
      ans << all
    end
  end

  ans
end

这是我要放入方法中的示例:

list_squared(1, 250)

然后所需的输出将是一个数组数组,每个数组包含其平方因子之和为完美平方的数字以及这些平方因子之和:

[[1, 1], [42, 2500], [246, 84100]]

我将从介绍一些辅助方法(factorssquare?)开始,让您的代码更具可读性。

此外,我会减少范围和数组的数量以提高内存使用率。

require 'prime'

def factors(number)
  [1].tap do |factors|
    primes = number.prime_division.flat_map { |p, e| Array.new(e, p) }
    (1..primes.size).each do |i| 
      primes.combination(i).each do |combination| 
        factor = combination.inject(:*)
        factors << factor unless factors.include?(factor)
      end
    end
  end
end

def square?(number)
  square = Math.sqrt(number)
  square == square.floor
end

def list_squared(m, n)
  (m..n).map do |number|
    sum = factors(number).inject { |sum, x| sum + x ** 2 }
    [number, sum] if square?(sum)
  end.compact
end

list_squared(1, 250)

范围较窄(最多 250)的基准仅显示出微小的改进:

require 'benchmark'
n = 1_000

Benchmark.bmbm(15) do |x|
  x.report("original_list_squared :") { n.times do; original_list_squared(1, 250); end }
  x.report("improved_list_squared :") { n.times do; improved_list_squared(1, 250); end }
end

# Rehearsal -----------------------------------------------------------
# original_list_squared :   2.720000   0.010000   2.730000 (  2.741434)
# improved_list_squared :   2.590000   0.000000   2.590000 (  2.604415)
# -------------------------------------------------- total: 5.320000sec

#                               user     system      total        real
# original_list_squared :   2.710000   0.000000   2.710000 (  2.721530)
# improved_list_squared :   2.620000   0.010000   2.630000 (  2.638833)

但是具有更宽范围(高达 10000)的基准显示比原始实现更好的性能:

require 'benchmark'
n = 10

Benchmark.bmbm(15) do |x|
  x.report("original_list_squared :") { n.times do; original_list_squared(1, 10000); end }
  x.report("improved_list_squared :") { n.times do; improved_list_squared(1, 10000); end }
end

# Rehearsal -----------------------------------------------------------
# original_list_squared :  36.400000   0.160000  36.560000 ( 36.860889)
# improved_list_squared :   2.530000   0.000000   2.530000 (  2.540743)
# ------------------------------------------------- total: 39.090000sec

#                               user     system      total        real
# original_list_squared :  36.370000   0.120000  36.490000 ( 36.594130)
# improved_list_squared :   2.560000   0.010000   2.570000 (  2.581622)

tl;dr:N 越大,与原始实现相比,我的代码性能越好...

提高效率的一种方法是使用 Ruby 的内置方法 Prime::prime_division

对于任何数字 n,如果 prime_division returns 包含单个元素的数组,则该元素将是 [n,1] 并且 n 将是显示为质数。该质数有因数 n1,因此必须与非质数区别对待。

require 'prime'

def list_squared(range)
  range.each_with_object({}) do |i,h|
    facs = Prime.prime_division(i)
    ssq = 
    case facs.size
    when 1 then facs.first.first**2 + 1
    else facs.inject(0) { |tot,(a,b)| tot + b*(a**2) }
    end
    h[i] = facs if (Math.sqrt(ssq).to_i)**2 == ssq
  end
end

list_squared(1..10_000)
  #=> { 1=>[], 48=>[[2, 4], [3, 1]], 320=>[[2, 6], [5, 1]], 351=>[[3, 3], [13, 1]],
  #     486=>[[2, 1], [3, 5]], 1080=>[[2, 3], [3, 3], [5, 1]],
  #     1260=>[[2, 2], [3, 2], [5, 1], [7, 1]], 1350=>[[2, 1], [3, 3], [5, 2]],
  #     1375=>[[5, 3], [11, 1]], 1792=>[[2, 8], [7, 1]], 1836=>[[2, 2], [3, 3], [17, 1]],
  #     2070=>[[2, 1], [3, 2], [5, 1], [23, 1]], 2145=>[[3, 1], [5, 1], [11, 1], [13, 1]],
  #     2175=>[[3, 1], [5, 2], [29, 1]], 2730=>[[2, 1], [3, 1], [5, 1], [7, 1], [13, 1]],
  #     2772=>[[2, 2], [3, 2], [7, 1], [11, 1]], 3072=>[[2, 10], [3, 1]],
  #     3150=>[[2, 1], [3, 2], [5, 2], [7, 1]], 3510=>[[2, 1], [3, 3], [5, 1], [13, 1]],
  #     4104=>[[2, 3], [3, 3], [19, 1]], 4305=>[[3, 1], [5, 1], [7, 1], [41, 1]],
  #     4625=>[[5, 3], [37, 1]], 4650=>[[2, 1], [3, 1], [5, 2], [31, 1]],
  #     4655=>[[5, 1], [7, 2], [19, 1]], 4998=>[[2, 1], [3, 1], [7, 2], [17, 1]],
  #     5880=>[[2, 3], [3, 1], [5, 1], [7, 2]], 6000=>[[2, 4], [3, 1], [5, 3]],
  #     6174=>[[2, 1], [3, 2], [7, 3]], 6545=>[[5, 1], [7, 1], [11, 1], [17, 1]],
  #     7098=>[[2, 1], [3, 1], [7, 1], [13, 2]], 7128=>[[2, 3], [3, 4], [11, 1]],
  #     7182=>[[2, 1], [3, 3], [7, 1], [19, 1]], 7650=>[[2, 1], [3, 2], [5, 2], [17, 1]],
  #     7791=>[[3, 1], [7, 2], [53, 1]], 7889=>[[7, 3], [23, 1]],
  #     7956=>[[2, 2], [3, 2], [13, 1], [17, 1]],
  #     9030=>[[2, 1], [3, 1], [5, 1], [7, 1], [43, 1]],
  #     9108=>[[2, 2], [3, 2], [11, 1], [23, 1]], 9295=>[[5, 1], [11, 1], [13, 2]],
  #     9324=>[[2, 2], [3, 2], [7, 1], [37, 1]]} 

这个计算大约用了 0.15 秒。

对于i = 6174

 (2**1) * (3**2) * (7**3) #=> 6174

 1*(2**2) + 2*(3**2) + 3*(7**2) #=> 169 == 13*13 

经常解决此类问题的技巧是从试题切换到sieve。在 Python(抱歉):

def list_squared(m, n):
    factor_squared_sum = {i: 0 for i in range(m, n + 1)}
    for factor in range(1, n + 1):
        i = n - n % factor  # greatest multiple of factor less than or equal to n
        while i >= m:
            factor_squared_sum[i] += factor ** 2
            i -= factor
    return {i for (i, fss) in factor_squared_sum.items() if isqrt(fss) ** 2 == fss}


def isqrt(n):
    # from 
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

下一个优化是将 factor 步进到 isqrt(n),成对添加因子平方(例如,2i // 2)。