如何在缓存结果时更新非局部变量?
Posted
技术标签:
【中文标题】如何在缓存结果时更新非局部变量?【英文标题】:How to update nonlocal variables while caching results? 【发布时间】:2021-10-26 17:11:27 【问题描述】:当使用像 lru_cache 这样的 functools 缓存函数时,内部函数不会更新非局部变量的值。同样的方法在没有装饰器的情况下也可以工作。
使用缓存装饰器时非局部变量是否没有更新?另外,如果我必须更新非局部变量但还要存储结果以避免重复工作,该怎么办?或者我是否需要从缓存函数中返回答案?
例如。以下没有正确更新非局部变量的值
def foo(x):
outer_var=0
@lru_cache
def bar(i):
nonlocal outer_var
if condition:
outer_var+=1
else:
bar(i+1)
bar(x)
return outer_var
背景
我正在尝试Decode Ways problem,它正在寻找可以将一串数字解释为字母的多种方式。我从第一个字母开始,采取一两个步骤并检查它们是否有效。到达字符串末尾时,我更新了一个非局部变量,该变量存储了可能的方式数。此方法在不使用 lru_cache 的情况下给出正确答案,但在使用缓存时失败。我返回值的另一种方法正在工作,但我想检查如何在使用记忆装饰器时更新非局部变量。
我的代码有错误:
ways=0
@lru_cache(None) # works well without this
def recurse(i):
nonlocal ways
if i==len(s):
ways+=1
elif i<len(s):
if 1<=int(s[i])<=9:
recurse(i+1)
if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
recurse(i+2)
return
recurse(0)
return ways
公认的解决方案:
@lru_cache(None)
def recurse(i):
if i==len(s):
return 1
elif i<len(s):
ans=0
if 1<=int(s[i])<=9:
ans+= recurse(i+1)
if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
ans+= recurse(i+2)
return ans
return recurse(0)
【问题讨论】:
您能否提供一个最小的示例来特别强调缓存问题,而无需围绕它的整个“解码方式”算法? 嗨@Stef,感谢您的建议。我添加了最小的示例,但保留了原始问题的上下文。问题的症结在于更新函数外部的非局部变量,同时仍然缓存。缓存是使用 functools lru_cache 完成的,有助于避免再次计算相同参数的值。 【参考方案1】:lru_cache
没有什么特别之处,nonlocal
变量或递归会导致任何固有问题,就其本身而言。这个问题纯粹是逻辑上的,而不是行为异常。看这个最小的例子:
from functools import lru_cache
def foo():
c = 0
@lru_cache(None)
def bar(i=0):
nonlocal c
if i < 5:
c += 1
bar(i + 1)
bar()
return c
print(foo()) # => 5
解码方式代码的缓存版本中的问题是由于递归调用的重叠性质。缓存可防止基本案例调用 recurse(i)
其中 i == len(s)
多次执行,即使它是从不同的递归路径到达的。
确定这一点的一个好方法是在基本情况下(if i == len(s)
分支)打一个print("hello")
,然后给它一个相当大的问题。你会看到print("hello")
触发一次,而且只有一次,并且由于ways
不能通过recurse(i)
和i == len(s)
之外的任何其他方式更新,所以当一切都说完之后,你只剩下ways == 1
.
在上面的玩具示例中,只有一个递归路径:每个 i
的调用在 0 和 9 之间扩展,并且从不使用缓存。相比之下,解码方式提供了多个递归路径,因此通过recurse(i+1)
的路径线性地找到基本情况,然后随着堆栈展开,recurse(i+2)
尝试找到其他方式来达到它。
添加缓存会切断额外的路径,但对于每个中间节点没有返回值。有了缓存,就好像你有一个子问题的记忆表或动态规划表,但你从不更新任何条目,所以整个表为零(基本情况除外)。
这是缓存导致的线性行为的示例:
from functools import lru_cache
def cached():
@lru_cache(None)
def cached_recurse(i=0):
print("cached", i)
if i < 3:
cached_recurse(i + 1)
cached_recurse(i + 2)
cached_recurse()
def uncached():
def uncached_recurse(i=0):
print("uncached", i)
if i < 3:
uncached_recurse(i + 1)
uncached_recurse(i + 2)
uncached_recurse()
cached()
uncached()
输出:
cached 0
cached 1
cached 2
cached 3
cached 4
uncached 0
uncached 1
uncached 2
uncached 3
uncached 4
uncached 3
uncached 2
uncached 3
uncached 4
解决方案与您展示的完全一样:将结果向上传递并使用缓存来存储代表子问题的每个节点的值。这是两全其美:我们有子问题的值,但没有重新执行最终导致您的 ways += 1
基本情况的函数。
换句话说,如果您要使用缓存,请将其视为查找表,而不仅仅是调用树修剪器。在您的尝试中,它不记得做了什么工作,只是阻止它再次完成。
【讨论】:
以上是关于如何在缓存结果时更新非局部变量?的主要内容,如果未能解决你的问题,请参考以下文章
自动化测试时需要使用python,请问如何理解python中的全局变量和局部变量?