如何在缓存结果时更新非局部变量?

Posted

技术标签:

【中文标题】如何在缓存结果时更新非局部变量?【英文标题】:How to update nonlocal variables while caching results? 【发布时间】:2021-10-26 17:11:27 【问题描述】:

当使用像 lru_cache 这样的 functools 缓存函数时,内部函数不会更新非局部变量的值。同样的方法在没有装饰器的情况下也可以工作。

使用缓存装饰器时非局部变量是否没有更新?另外,如果我必须更新非局部变量但还要存储结果以避免重复工作,该怎么办?或者我是否需要从缓存函数中返回答案?

例如。以下没有正确更新非局部变量的值

def foo(x):
    outer_var=0

    @lru_cache
    def bar(i):
        nonlocal outer_var
        if condition:
            outer_var+=1
        else:
            bar(i+1)

    bar(x)
    return outer_var

背景

我正在尝试Decode Ways problem,它正在寻找可以将一串数字解释为字母的多种方式。我从第一个字母开始,采取一两个步骤并检查它们是否有效。到达字符串末尾时,我更新了一个非局部变量,该变量存储了可能的方式数。此方法在不使用 lru_cache 的情况下给出正确答案,但在使用缓存时失败。我返回值的另一种方法正在工作,但我想检查如何在使用记忆装饰器时更新非局部变量。

我的代码有错误:

ways=0
@lru_cache(None) # works well without this
def recurse(i):
    nonlocal ways
    if i==len(s):
        ways+=1
    elif i<len(s):
        if 1<=int(s[i])<=9:
            recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            recurse(i+2)
    return 

recurse(0)
return ways

公认的解决方案:

@lru_cache(None)
def recurse(i):
    if i==len(s):
        return 1

    elif i<len(s):
        ans=0
        if 1<=int(s[i])<=9:
            ans+= recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            ans+= recurse(i+2)
        return ans

return recurse(0)

【问题讨论】:

您能否提供一个最小的示例来特别强调缓存问题,而无需围绕它的整个“解码方式”算法? 嗨@Stef,感谢您的建议。我添加了最小的示例,但保留了原始问题的上下文。问题的症结在于更新函数外部的非局部变量,同时仍然缓存。缓存是使用 functools lru_cache 完成的,有助于避免再次计算相同参数的值。 【参考方案1】:

lru_cache 没有什么特别之处,nonlocal 变量或递归会导致任何固有问题,就其本身而言。这个问题纯粹是逻辑上的,而不是行为异常。看这个最小的例子:

from functools import lru_cache

def foo():
    c = 0

    @lru_cache(None)
    def bar(i=0):
        nonlocal c

        if i < 5:
            c += 1
            bar(i + 1)

    bar()
    return c

print(foo()) # => 5

解码方式代码的缓存版本中的问题是由于递归调用的重叠性质。缓存可防止基本案例调用 recurse(i) 其中 i == len(s) 多次执行,即使它是从不同的递归路径到达的。

确定这一点的一个好方法是在基本情况下(if i == len(s) 分支)打一个print("hello"),然后给它一个相当大的问题。你会看到print("hello") 触发一次,而且只有一次,并且由于ways 不能通过recurse(i)i == len(s) 之外的任何其他方式更新,所以当一切都说完之后,你只剩下ways == 1 .

在上面的玩具示例中,只有一个递归路径:每个 i 的调用在 0 和 9 之间扩展,并且从不使用缓存。相比之下,解码方式提供了多个递归路径,因此通过recurse(i+1) 的路径线性地找到基本情况,然后随着堆栈展开,recurse(i+2) 尝试找到其他方式来达到它。

添加缓存会切断额外的路径,但对于每个中间节点没有返回值。有了缓存,就好像你有一个子问题的记忆表或动态规划表,但你从不更新任何条目,所以整个表为零(基本情况除外)。

这是缓存导致的线性行为的示例:

from functools import lru_cache

def cached():
    @lru_cache(None)
    def cached_recurse(i=0):
        print("cached", i)

        if i < 3:
            cached_recurse(i + 1)
            cached_recurse(i + 2)

    cached_recurse()

def uncached():
    def uncached_recurse(i=0):
        print("uncached", i)

        if i < 3:
            uncached_recurse(i + 1)
            uncached_recurse(i + 2)

    uncached_recurse()

cached()
uncached()

输出:

cached 0
cached 1
cached 2
cached 3
cached 4
uncached 0
uncached 1
uncached 2
uncached 3
uncached 4
uncached 3
uncached 2
uncached 3
uncached 4

解决方案与您展示的完全一样:将结果向上传递并使用缓存来存储代表子问题的每个节点的值。这是两全其美:我们有子问题的值,但没有重新执行最终导致您的 ways += 1 基本情况的函数。

换句话说,如果您要使用缓存,请将其视为查找表,而不仅仅是调用树修剪器。在您的尝试中,它不记得做了什么工作,只是阻止它再次完成。

【讨论】:

以上是关于如何在缓存结果时更新非局部变量?的主要内容,如果未能解决你的问题,请参考以下文章

最终的局部变量不能分配,不能分配给非最终变量

自动化测试时需要使用python,请问如何理解python中的全局变量和局部变量?

Python:如何允许“内部函数”更改多个“外部函数”中的非局部变量

局部变量和全局变量

Python中的非局部变量

静态局部变量