使用聚合的 Scala 并行频率计算不起作用

Scala parallel frequency calculation using aggregate doesn't work

我正在通过书中的练习来学习 Scala "Scala for the Impatient"。请参阅以下问题以及我的答案和代码。我想知道我的回答是否正确。该代码也不起作用(所有频率均为 1)。错误在哪里?

Q10: Harry Hacker reads a file into a string and wants to use a parallel collection to update the letter frequencies concurrently on portions of the string. He uses the following code:

val frequencies = new scala.collection.mutable.HashMap[Char, Int]
for (c <- str.par) frequencies(c) = frequencies.getOrElse(c, 0) + 1

Why is this a terrible idea? How can he really parallelize the computation?

我的回答: 这不是一个好主意,因为如果 2 个线程同时更新相同的频率,结果是不确定的。

我的代码:

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => { m + (c -> (m.getOrElse(c, 0) + 1)) }, _ ++ _)
}

单元测试:

"Method parFrequency" should "return the frequency of each character in a string" in {
  val freq = parFrequency("harry hacker")

  freq should have size 8

  freq('h') should be(2) // fails
  freq('a') should be(2)
  freq('r') should be(3)
  freq('y') should be(1)
  freq(' ') should be(1)
  freq('c') should be(1)
  freq('k') should be(1)
  freq('e') should be(1)
}

编辑: 阅读 this 线程后,我更新了代码。现在,如果单独 运行 测试有效,但如果 运行 作为套件测试失败。

def parFrequency(str: String) = {
  val freq = ImmutableHashMap[Char, Int]()
  str.par.aggregate(freq)((_, c) => ImmutableHashMap(c -> 1), (m1, m2) => m1.merged(m2)({
      case ((k, v1), (_, v2)) => (k, v1 + v2)
  }))
}

编辑 2: 请参阅下面我的解决方案。

++ 不组合相同键的值。因此,当您合并映射时,您将获得(对于共享键)其中一个值(在本例中始终为 1),而不是值的总和。

这个有效:

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => { m + (c -> (m.getOrElse(c, 0) + 1)) },
  (a,b) => b.foldLeft(a){case (acc, (k,v))=> acc updated (k, acc.getOrElse(k,0) + v) })
} 

val freq = parFrequency("harry hacker")
//> Map(e -> 1, y -> 1, a -> 2,   -> 1, c -> 1, h -> 2, r -> 3, k -> 1)

foldLeft 遍历其中一张地图,用找到的 key/values 更新另一张地图。

您在第一种情况下遇到麻烦,因为您自己检测到的是 ++ 运算符,它只是连接,丢弃了同一键的第二次出现。

现在在第二种情况下,您有 (_, c) => ImmutableHashMap(c -> 1),它只是删除了我在 seqop 阶段找到的所有字符。

我的建议是使用特殊的组合操作扩展 Map 类型,像 HashMap 中的 merged 一样工作,并在 seqop 阶段保留第一个示例的收集:

implicit class MapUnionOps[K, V](m1: Map[K, V]) {
  def unionWith[V1 >: V](m2: Map[K, V1])(f: (V1, V1) => V1): Map[K, V1] = {
    val kv1 = m1.filterKeys(!m2.contains(_))
    val kv2 = m2.filterKeys(!m1.contains(_))
    val common = (m1.keySet & m2.keySet).toSeq map (k => (k, f(m1(k), m2(k))))
    (common ++ kv1 ++ kv2).toMap
  }
}

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => {m + (c -> (m.getOrElse(c, 0) + 1))}, (m1, m2) =>  (m1 unionWith m2)(_ + _))
}

或者您可以使用 Paul 的回答中的 fold 解决方案,但为了每次合并的性能更好,请选择较小的地图来遍历:

implicit class MapUnionOps[K, V](m1: Map[K, V]) {
  def unionWith(m2: Map[K, V])(f: (V, V) => V): Map[K, V] =
    if (m2.size > m1.size) m2.unionWith(m1)(f)
    else m2.foldLeft(m1) {
      case (acc, (k, v)) => acc + (k -> acc.get(k).fold(v)(f(v, _)))
    }
}

这似乎有效。我比这里提出的其他解决方案更喜欢它,因为:

  1. 它的代码比 implicit class 少很多,比使用 getOrElsefoldLeft 的代码少得多。
  2. 它使用 API 中的 merged 函数,该函数旨在执行我想要的操作。
  3. 这是我自己的解决方案:)

    def parFrequency(str: String) = {
      val freq = ImmutableHashMap[Char, Int]()
      str.par.aggregate(freq)((_, c) => ImmutableHashMap(c -> 1), _.merged(_) {
        case ((k, v1), (_, v2)) => (k, v1 + v2)
      })
    }
    

感谢您花时间帮助我。