TailRec 和 State monad 的组成
Composition of TailRec and State monad
为了简化我的问题,我将从一个学术示例开始,即 ackermann 函数。
我使用以下递归的简单实现:
def a(m: BigInt, n: BigInt): BigInt = {
if (m == 0) {
n + 1
} else if (m > 0 && n == 0) a(m - 1, 1)
else a(m - 1, a(m, n - 1))
}
这不是最优的,很快就会在堆栈溢出中结束。
所以我构建了一个新的实现,它使用标准 scala 库中的 TailRec 并给出了它:
import scala.util.control.TailCalls._
private[this] def a_impl(m: BigInt, n: BigInt): TailRec[BigInt] = {
if (m == 0) {
done(n + 1)
} else if (m > 0 && n == 0) tailcall(a_impl(m - 1, 1))
else
for {
x <- tailcall(a_impl(m, n - 1))
y <- tailcall(a_impl(m - 1, x))
} yield y
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n).result
}
它可以工作,但速度很慢。
所以我构建了一个使用 State monad 的新实现,但我又一次失去了终端递归。
type Memo = Map[(BigInt, BigInt), BigInt]
private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
if (m == 0) {
State.init(n + 1)
} else {
for {
memoed <- State.gets { memo: Memo => memo get (m, n) }
res <- memoed match {
case Some(ack) => State.init[Memo, BigInt](ack)
case None =>
if (m > 0 && n == 0) for {
a <- a_impl(m - 1, 1)
_ <- State.update { memo: Memo => memo + ((m, n) -> a) }
} yield a
else for {
a <- a_impl(m, n - 1)
b <- a_impl(m - 1, a)
_ <- State.update { memo: Memo => memo + ((m, n) -> b) }
} yield b
}
} yield res
}
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n) eval (Map())
}
所以我的问题是,如何同时使用 State 和 TailRec?
我看过 Monad Transformer 的概念,但我真的不知道如何在我的示例中使用它。
我什至不知道该使用哪种类型,我可以在那个和这个之间做出选择:
type TailRecWithState = TailRec[State[Memo, BigInt]]
// or
type StateWithTailRec = State[Memo, TailRec[BigInt]]
你能帮我指出这个例子的正确方向吗(然后我会处理我的实际案例)?
我知道至少在猫中,State[S, A]
是 StateT[Eval, S, A]
的类型别名,其中 Eval
与 TailRec
完全符合您的要求 - 堆栈-安全延迟执行。这对我来说很好用:
import cats._, cats.data._, cats.implicits._
type Memo = Map[(BigInt, BigInt), BigInt]
private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
if (m == 0) {
State.pure(n + 1)
} else {
for {
memoed <- State.inspect[Memo, Option[BigInt]](s => s.get((m, n)))
res <- memoed match {
case Some(x) => State.pure[Memo, BigInt](x)
case None => {
if (n == 0) for {
a <- a_impl(m - 1, 1)
_ <- State.modify[Memo](s => s + ((m, n) -> a))
} yield a
else for {
a <- a_impl(m, n - 1)
b <- a_impl(m - 1, a)
_ <- State.modify[Memo](s => s + ((m, n) -> b))
} yield b
}
}
} yield res
}
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n).runA(Map()).value
}
我猜 scalaz 可能也有一些类似的 StateT
和 Eval
,尽管我不熟悉这个库。
为了简化我的问题,我将从一个学术示例开始,即 ackermann 函数。
我使用以下递归的简单实现:
def a(m: BigInt, n: BigInt): BigInt = {
if (m == 0) {
n + 1
} else if (m > 0 && n == 0) a(m - 1, 1)
else a(m - 1, a(m, n - 1))
}
这不是最优的,很快就会在堆栈溢出中结束。 所以我构建了一个新的实现,它使用标准 scala 库中的 TailRec 并给出了它:
import scala.util.control.TailCalls._
private[this] def a_impl(m: BigInt, n: BigInt): TailRec[BigInt] = {
if (m == 0) {
done(n + 1)
} else if (m > 0 && n == 0) tailcall(a_impl(m - 1, 1))
else
for {
x <- tailcall(a_impl(m, n - 1))
y <- tailcall(a_impl(m - 1, x))
} yield y
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n).result
}
它可以工作,但速度很慢。 所以我构建了一个使用 State monad 的新实现,但我又一次失去了终端递归。
type Memo = Map[(BigInt, BigInt), BigInt]
private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
if (m == 0) {
State.init(n + 1)
} else {
for {
memoed <- State.gets { memo: Memo => memo get (m, n) }
res <- memoed match {
case Some(ack) => State.init[Memo, BigInt](ack)
case None =>
if (m > 0 && n == 0) for {
a <- a_impl(m - 1, 1)
_ <- State.update { memo: Memo => memo + ((m, n) -> a) }
} yield a
else for {
a <- a_impl(m, n - 1)
b <- a_impl(m - 1, a)
_ <- State.update { memo: Memo => memo + ((m, n) -> b) }
} yield b
}
} yield res
}
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n) eval (Map())
}
所以我的问题是,如何同时使用 State 和 TailRec?
我看过 Monad Transformer 的概念,但我真的不知道如何在我的示例中使用它。 我什至不知道该使用哪种类型,我可以在那个和这个之间做出选择:
type TailRecWithState = TailRec[State[Memo, BigInt]]
// or
type StateWithTailRec = State[Memo, TailRec[BigInt]]
你能帮我指出这个例子的正确方向吗(然后我会处理我的实际案例)?
我知道至少在猫中,State[S, A]
是 StateT[Eval, S, A]
的类型别名,其中 Eval
与 TailRec
完全符合您的要求 - 堆栈-安全延迟执行。这对我来说很好用:
import cats._, cats.data._, cats.implicits._
type Memo = Map[(BigInt, BigInt), BigInt]
private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
if (m == 0) {
State.pure(n + 1)
} else {
for {
memoed <- State.inspect[Memo, Option[BigInt]](s => s.get((m, n)))
res <- memoed match {
case Some(x) => State.pure[Memo, BigInt](x)
case None => {
if (n == 0) for {
a <- a_impl(m - 1, 1)
_ <- State.modify[Memo](s => s + ((m, n) -> a))
} yield a
else for {
a <- a_impl(m, n - 1)
b <- a_impl(m - 1, a)
_ <- State.modify[Memo](s => s + ((m, n) -> b))
} yield b
}
}
} yield res
}
}
def a(m: BigInt, n: BigInt): BigInt = {
a_impl(m, n).runA(Map()).value
}
我猜 scalaz 可能也有一些类似的 StateT
和 Eval
,尽管我不熟悉这个库。