在 Scala 中是不是有一种通用的方式来记忆?

Posted

技术标签:

【中文标题】在 Scala 中是不是有一种通用的方式来记忆?【英文标题】:Is there a generic way to memoize in Scala?在 Scala 中是否有一种通用的方式来记忆? 【发布时间】:2013-04-21 20:43:56 【问题描述】:

我想记住这个:

def fib(n: Int) = if(n <= 1) 1 else fib(n-1) + fib(n-2)
println(fib(100)) // times out

所以我写了这个,这令人惊讶地编译和工作(我很惊讶,因为fib 在它的声明中引用了它自己):

case class Memo[A,B](f: A => B) extends (A => B) 
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))


val fib: Memo[Int, BigInt] = Memo 
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 


println(fib(100))     // prints 100th fibonacci number instantly

但是当我尝试在 def 中声明 fib 时,我得到一个编译器错误:

def foo(n: Int) = 
  val fib: Memo[Int, BigInt] = Memo 
    case 0 => 0
    case 1 => 1
    case n => fib(n-1) + fib(n-2) 
  
  fib(n)
 

以上编译失败error: forward reference extends over definition of value fib case n => fib(n-1) + fib(n-2)

为什么在 def 内声明 val fib 失败,但在类/对象范围外声明有效?

为了澄清,为什么我可能想在 def 范围内声明递归记忆函数 - 这是我对子集和问题的解决方案:

/**
   * Subset sum algorithm - can we achieve sum t using elements from s?
   *
   * @param s set of integers
   * @param t target
   * @return true iff there exists a subset of s that sums to t
   */
  def subsetSum(s: Seq[Int], t: Int): Boolean = 
    val max = s.scanLeft(0)((sum, i) => (sum + i) max sum)  //max(i) =  largest sum achievable from first i elements
    val min = s.scanLeft(0)((sum, i) => (sum + i) min sum)  //min(i) = smallest sum achievable from first i elements

    val dp: Memo[(Int, Int), Boolean] = Memo          // dp(i,x) = can we achieve x using the first i elements?
      case (_, 0) => true        // 0 can always be achieved using empty set
      case (0, _) => false       // if empty set, non-zero cannot be achieved
      case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x)  // try with/without s(i-1)
      case _ => false            // outside range otherwise
    

    dp(s.length, t)
  

【问题讨论】:

查看我的blog post 了解递归函数记忆的另一种变体。 在我向 SO 发布任何内容之前,我先用 Google 搜索它,而您的博客文章是第一个结果 :) 我同意这是“正确”的方法 - 使用 Y-combinator。但是,我认为使用我的风格并利用 lazy val 看起来比为每个函数有 2 个定义(递归定义和 Y 组合定义)更干净。看起来这 [looks](1) [1] 有多干净:github.com/pathikrit/scalgos/blob/master/src/main/scala/com/… 我对上述问题中的一些简洁语法感到困惑(特别是案例类对“扩展(A => B)”的使用。我发布了一个关于它的问题:***.com/questions/19548103/… Map:***.com/questions/6806123/…带来的并发问题慎用此模式 正文中提出的问题和接受的答案与此问题的标题无关。能改一下标题吗? 【参考方案1】:

我找到了一种使用 Scala 进行记忆的更好方法:

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() 
  override def apply(key: I) = getOrElseUpdate(key, f(key))

现在你可以这样写斐波那契:

lazy val fib: Int => BigInt = memoize 
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)

这是一个带有多个参数的函数(选择函数):

lazy val c: ((Int, Int)) => BigInt = memoize 
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)

这是子集和问题:

// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = 
  // f is (i, j) => Boolean i.e. can the first i elements of s add up to j
  lazy val f: ((Int, Int)) => Boolean = memoize 
    case (_, 0) => true        // 0 can always be achieved using empty list
    case (0, _) => false       // we can never achieve non-zero if we have empty list
    case (i, j) => 
      val k = i - 1            // try the kth element
      f(k, j - s(k)) || f(k, j)
  
  f(s.length, t)

编辑:如下所述,这是一个线程安全的版本

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() self =>
  override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))

【讨论】:

我不认为这(或我见过的大多数基于mutable.Map 的实现)是线程安全的?但如果在单线程上下文中使用,看起来语法不错。 我不确定可变 HashMap 实现是否真的会以某种方式崩溃和/或损坏数据,或者主要问题是否只是缺少更新;对于大多数用例来说,缺少更新可能是可以接受的。 @Gary Coady:如果你想要并发,用 HashMap 替换 HashMap 很简单 我想知道你是否可以在 TrieMap 上死锁。毕竟,地图是在 getOrElseUpdate 方法中“递归”访问的。 @pathikrit:我看不出使用 mutable.HashMap 的self.synchronized 版本有什么问题。我在这里的评论主要是对上面 cmets 中对TrieMap 的讨论的澄清,因为事实证明不可能简单地将TrieMap 子插入给定的代码。【参考方案2】:

类/特征级别val 编译为方法和私有变量的组合。因此允许递归定义。

另一方面,本地vals 只是常规变量,因此不允许递归定义。

顺便说一句,即使您定义的 def 有效,它也不会达到您的预期。在每次调用foo 时,都会创建一个新的函数对象fib,并且它会有自己的支持映射。相反,您应该做的是(如果您真的希望 def 成为您的公共界面):

private val fib: Memo[Int, BigInt] = Memo 
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 


def foo(n: Int) = 
  fib(n)
 

【讨论】:

'foo' 和 'fib' 只是一个简化 - 在我的情况下,foo 是子集和问题,而 fib 是输入集上的递归记忆,因此我不能简单地提取我的方法外的记忆函数。您能解释一下“类级 val 编译为方法和私有变量的组合”部分的含义吗?我应该注意类和方法vals 之间的其他区别? i) 是什么阻止您在方法之外提取它? ii) 当你在类/特质级别写val x = N 时,你得到的是def x = _xprivate val _x = N。你应该在任何 Scala 书籍中找到这个解释。我不记得字段vals 和本地vals 之间的任何其他差异。 即使在本地范围内也可以使用的解决方法:将fib 设为lazy val。然后你应该能够在本地范围内重复它。 如果它使用了可变状态和 val。这是否意味着它不是线程安全的? @ses,除非那个可变状态有线程安全保证。 (你可以是可变的和线程安全的。只是......更困难。)【参考方案3】:

Scalaz 有一个解决方案,为什么不重用它?

import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo 
  case 0 => 0
  case 1 => 1
  case n => fib(n-2) + fib(n-1)

你可以阅读更多关于memoization in Scalaz的信息。

【讨论】:

【参考方案4】:

可变 HashMap 不是线程安全的。同样为基本条件单独定义 case 语句似乎是不必要的特殊处理,而 Map 可以加载初始值并传递给 Memoizer。以下是 Memoizer 的签名,它接受一个备忘录(不可变映射)和公式并返回一个递归函数。

Memoizer 看起来像

def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O

现在给出以下斐波那契公式,

def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2)

带有 Memoizer 的斐波那契可以定义为

val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib)

上下文无关的通用 Memoizer 被定义为

    def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = 
        var memo = map
        def recur(n: I): O = 
          if( memo contains n) 
            memo(n) 
           else 
            val result = formula(recur, n)
            memo += (n -> result)
            result
          
        
        recur
      

同样,对于阶乘,公式是

def fac(f: Int => Int, n: Int): Int = n * f(n-1)

Memoizer 的阶乘是

val factorial = memoize( Map(0 -> 1, 1 -> 1), fac)

灵感:记忆,Douglas Crockford 的 javascript 优秀部分的第 4 章

【讨论】:

> 为基本条件单独定义 case 语句似乎不必要的特殊处理 真的吗?实际上 fib 是具有简单基本案例的罕见示例之一。你会如何使用这个解决背包问题(github.com/pathikrit/scalgos/blob/master/src/main/scala/com/…)? 在斐波那契或任何预先知道值的情况下,应在地图中预加载。它使公式函数更接近其数学定义 IMO。如果公式需要比较(case 语句或 if...else 块),例如在解决背包问题时,使用 case 语句是完全可以的。【参考方案5】:

ZIO#cached 是一种在 ZIO 中进行记忆的方法

【讨论】:

以上是关于在 Scala 中是不是有一种通用的方式来记忆?的主要内容,如果未能解决你的问题,请参考以下文章

Scala 是不是有一种按列拆分 CSV 的好方法?

在 C++ 中是不是有一种最佳方式来运行指向该值的指针链?

标准 C++ 中是不是有一种可移植的方式来检索主机名?

Scala都认识?厉害了~

是否有一种“正确”的方式来读取 CSV 文件 [重复]

是否有一种“正确”的方式来读取 CSV 文件 [重复]