在haskell中记忆多维递归解决方案
Posted
技术标签:
【中文标题】在haskell中记忆多维递归解决方案【英文标题】:Memoize multi-dimensional recursive solutions in haskell 【发布时间】:2022-01-19 09:35:10 【问题描述】:我在 haskell 中解决了一个递归问题,尽管我可以得到我想要缓存子问题输出的解决方案,因为它具有重叠的子问题属性。
问题是,给定一个维度为n*m
的网格和一个整数k
。从 (1, 1) 到达网格 (n, m) 且方向变化不超过 k 的方式有多少种。
这是没有记忆的代码
paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
| i > n || j > m || k < 0 = 0
| i == n && j == m = 1
| dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2 -- is in grid (1,1)
| dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2 -- down was the direction took to reach here
| dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2 -- right was the direction took to reach here
| otherwise = -1
这里的因变量是i
、j
、k
、dir
。在像 c++/java 这样的语言中,可以使用 4-d DP 数组(dp[n][m][k][3]
,在 haskell 中我找不到实现它的方法。
【问题讨论】:
【参考方案1】:“打结”是一种众所周知的技术,它可以让 GHC 运行时为您记住结果,如果您提前知道所有需要查找的值的话。这个想法是将你的递归函数变成一个自引用的数据结构,然后简单地查找你真正关心的值。我为此选择使用 Array,但 Map 也可以。在任何一种情况下,您使用的数组或映射都必须是惰性/非严格的,因为我们将向其中插入在整个数组被填满之前我们还没有准备好计算的值。
import Data.Array (array, bounds, inRange, (!))
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
where go (i, j, k, dir)
| i == m && j == n = 1
| dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2) -- down was the direction took to reach here
| dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2) -- right was the direction took to reach here
| otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2) -- is in grid (1,1)
a = array ((1, 1, 0, 1), (m, n, k, 2))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
我稍微简化了你的 API:
m
和 n
参数不会随着每次迭代而改变,因此它们不应成为递归调用的一部分
客户端不必告诉您i
、j
和dir
的开头是什么,因此它们已从函数签名中删除,并隐式分别从 1、1 和 0 开始
我还调换了m
和n
的顺序,因为先使用n
参数很奇怪。这让我很头疼,因为我有一段时间没有注意到我还需要更改基本情况!
然后,正如我之前所说,我们的想法是用我们需要进行的所有递归调用来填充数组:这就是 array
调用。注意array
中的单元格是通过调用go
来初始化的,这(基本情况除外!)涉及调用get
,这涉及在数组中查找元素。这样,a
是自引用或递归的。但是我们不必决定以什么顺序查找,或者以什么顺序插入它们:我们足够懒惰,GHC 会根据需要评估数组元素。
我也有点厚脸皮,只在数组中为dir=1
和dir=2
留出空间,而不是dir=0
。我侥幸成功,因为dir=0
只在第一次调用时发生,我可以直接调用go
,绕过get
中的边界检查。这个技巧确实意味着如果你传递一个小于 1 的 m
或 n
或小于零的 k
,你会得到一个运行时错误。如果您需要处理这种情况,您可以为 paths
本身添加一个守卫。
当然,它确实有效:
> paths 3 3 2
4
您可以做的另一件事是为您的方向使用真实的数据类型,而不是 Int
:
import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)
data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
where go (i, j, k, dir)
| i == m && j == n = 1
| otherwise = case dir of
Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
a = array ((1, 1, 0, Down), (m, n, k, Right))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
(I 和 J 可能是比 Down 和 Right 更好的名字,我不知道这更容易还是更难记住)。我认为这可能是一种改进,因为这些类型现在有了更多的意义,而且你没有这个奇怪的otherwise
子句来处理像dir=7
这样应该是非法的事情。但它仍然有点不稳定,因为它依赖于枚举值的顺序:如果我们将 Neutral
放在 Down
和 Right
之间,它会中断。 (我尝试完全删除Neutral
方向并为第一步添加更多特殊情况,但这会以自己的方式变得丑陋)
【讨论】:
【参考方案2】:在 Haskell 中,这些事情确实不是最琐碎的事情。你真的很想进行一些就地突变来节省内存和时间,所以我认为没有比装备可怕的 ST
monad 更好的方法了。
这可以通过各种数据结构、数组、向量、repa 张量来完成。我从hashtables 中选择了HashTable
,因为它使用起来最简单,而且性能足以在我的示例中理解。
首先,介绍:
-# LANGUAGE Rank2Types #-
module Solution where
import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT
Rank2Types
在处理ST
时很有用,因为它是幻像类型。我选择了哈希表的Basic
变体,因为作者声称它具有最快的查找速度——而且我们会查找很多。
建议为地图使用类型别名,所以我们开始:
type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer
无 ST 入口点只是为了创建地图并调用我们的怪物:
runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
mem <- HT.new
paths mem i j n m k dir
这是paths
的记忆计算。我们只是尝试在地图中搜索结果,如果不存在则保存并返回:
mempaths mem i j n m k dir = do
res <- HT.lookup mem (i, j, k, dir)
case res of
Just x -> return x
Nothing -> do
x <- paths mem i j n m k dir
HT.insert mem (i, j, k, dir) x
return x
这里是算法的大脑。这只是一个使用带有记忆的调用来代替普通递归的单子动作:
paths mem i j n m k dir
| i > n || j > m || k < 0 = return 0
| i == n && j == m = return 1
| dir == 0 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m k 2 -- is in grid (1,1)
return $ x1 + x2
| dir == 1 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m (k-1) 2 -- down was the direction took to reach here
return $ x1 + x2
| dir == 2 = do
x1 <- mempaths mem (i+1) j n m (k-1) 1
x2 <- mempaths mem i (j+1) n m k 2 -- right was the direction took to reach here
return $ x1 + x2
| otherwise = return (-1)
【讨论】:
以上是关于在haskell中记忆多维递归解决方案的主要内容,如果未能解决你的问题,请参考以下文章