lru_cache分析

Posted yerikyu

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了lru_cache分析相关的知识,希望对你有一定的参考价值。

缓存

在计算机软件领域,缓存(Cache)指的是将部分数据存储在内存中,以便下次能够更快地访问这些数据,这也是一个典型的用空间换时间的例子。一般用于缓存的内存空间是固定的,当有更多的数据需要缓存的时候,需要将已缓存的部分数据清除后再将新的缓存数据放进去。需要清除哪些数据,就涉及到了缓存置换的策略,LRU(Least Recently Used,最近最少使用)是很常见的一个,也是 Python 中提供的缓存置换策略。

这个一种将定量数据加以保存以备迎合后续获取需求的处理方式,旨在加快数据获取的速度。数据的生成过程可能需要经过计算,规整,远程获取等操作,如果是同一份数据需要多次使用,每次都重新生成会大大浪费时间。所以,如果将计算或者远程请求等操作获得的数据缓存下来,会加快后续的数据获取需求。

什么是lru

计算机的缓存容量毕竟有限,如果缓存满了就要删除一些内容,给新内容腾位置。但问题是,删除哪些内容呢?我们肯定希望删掉哪些没什么用的缓存,而把有用的数据继续留在缓存里,方便之后继续使用。那么,什么样的数据,我们判定为有用的的数据呢?lru(即least recently used的缩写),即最近最少使用原则。表明缓存不会无限制增长,一段时间不用的缓存条目会被扔掉,这些有段时间没用到的数据就不是有用数据。

举个简单的例子,我先后打开了支付宝、微信、有道云笔记、qq之后,在后台排列的顺序就是

如果我这个时候访问支付宝,顺序就会变成

假设我的手机只允许运行三个程序,这个是时候如果我想再运行一下百度云,那么按照lru算法的执行顺序,就应该关闭"QQ"为百度云盘腾出运行的空间,毕竟QQ是最久未使用的应用,然后新开的百度云盘会被放在新开应用的最上面

现在应该可以稍微理解LRU策略了。当然还有其他缓存淘汰策略,比如不要按访问的时序来淘汰,而是按访问频率(LFU 策略)来淘汰等等,各有应用场景。这次主要是要来分析lru在python中的使用

算法描述

我们不妨借助146. LRU 缓存机制这道题来辅助理解。出题人就是想让我们设计足以实现lru的数据结构:

  • 要有一个接收capacity参数作为缓存的最大容量,
  • 实现两个API
    • put(key, val)方法存入键值对
    • get(key) 方法获取key对应的val,如果key不存在则返回 -1。

注意!按照出题者的期望,getput方法最好都是O(1) 的时间复杂度,我们举个具体例子来看看 LRU 算法怎么工作。

LRUCache lRUCache = new LRUCache(2);
lRUCache.put(1, 1); // 缓存是 {1=1}
lRUCache.put(2, 2); // 缓存是 {1=1, 2=2}, 因为最近访问了键 1, 所以提前至队头, 返回键 1 对应的值 1
lRUCache.get(1);    // 返回 1
lRUCache.put(3, 3); // 该操作会使得关键字 2 作废,缓存是 {1=1, 3=3}, 缓存容量已满,需要删除内容空出位置, 优先删除久未使用的数据,也就是队尾的数据, 然后把新的数据插入队头
lRUCache.get(2);    // 返回 -1 (未找到)
lRUCache.put(4, 4); // 该操作会使得关键字 1 作废,缓存是 {4=4, 3=3}, 键 1 已存在,把原始值 1 覆盖为 4, 需要注意的是要将键值对提前到队头 
lRUCache.get(1);    // 返回 -1 (未找到)
lRUCache.get(3);    // 返回 3
lRUCache.get(4);    // 返回 4

设计实现

分析上面的操作过程,要让putget方法的时间复杂度为O(1),我们可以总结出cache这个数据结构必要的条件:

  1. cache中的元素必须有时序,以区分最近使用的和最久未使用的数据,当容量满了之后要删除最久未使用的那个元素。由于存在顺序之分是,最常见的方式是通过链表或者数组实现
  2. 我们要在cache中快速找某个key是否已存在并得到对应的val,由于期望是在O(1)复杂度内这个过程就需要来通过哈希表来实现。
  3. 每次访问cache中的某个key,需要将这个元素变为最近使用的,也就是说cache要支持在任意位置快速插入和删除元素。可以在任意位置快速插入,我第一反应是通过双向链表来实现效果。

那么,什么数据结构同时符合上述条件呢?哈希表查找快,但是数据无固定顺序;链表有顺序之分,插入删除快,但是查找慢。所以结合一下,形成一种新的数据结构:哈希链表LinkedHashMap

LRU 缓存算法的核心数据结构就是哈希链表,即双向链表和哈希表的结合体。这个数据结构长这样:

借助这个结构,我们来逐一分析上面的 3 个条件:

  1. 如果我们每次默认从链表头部添加元素,那么显然越靠头部的元素就是最近使用的,越靠尾部的元素就是最久未使用的。
  2. 对于某一个 key,我们可以通过哈希表快速定位到链表中的节点,从而取得对应 val。
  3. 链表显然是支持在任意位置快速插入和删除的,改改指针就行。只不过传统的链表无法按照索引快速访问某一个位置的元素,而这里借助哈希表,可以通过 key 快速映射到任意一个链表节点,然后进行插入和删除。

代码实现

python已经有相关的实现如lru_cache。毕竟实践出真知,我们或许需要自己的来实现一遍之后才会有直观的感受和理解算法的细节,那就自己来造个轮子吧:)

首先,我们把双链表的节点类DLinkedNode写出来,为了简化,keyval都认为是整型:

class DLinkedNode:
    def __init__(self, key=0, value=0):
        self.key = key
        self.value = value
        self.prev = None
        self.next = None

然后依靠我们的Node类型构建一个双链表,实现几个双链表必须的 API:

  • addToHead: 添加最近使用的元素
  • removeNode: 删除某一个key
  • moveToHead: 将某个key提升为最近使用的元素
  • removeTail: 删除最久未使用的元素

    class LRUCache:
    ...
    def addToHead(self, node):
        node.prev = self.head
        node.next = self.head.next
        self.head.next.prev = node
        self.head.next = node
    
    def removeNode(self, node):
        node.prev.next = node.next
        node.next.prev = node.prev
    
    def moveToHead(self, node):
        self.removeNode(node)
        self.addToHead(node)
    
    def removeTail(self):
        node = self.tail.prev
        self.removeNode(node)
        return node   

    之后我们基于以上这四个API来实现哈希表的两个基本API:

  • get:实现起来方便一些
    • 判断key在不在哈希表中,如果不在返回-1
    • 反之,返回对应的node,之后将该node提升到head
  • put
    • 判断key在不在表中,如果已经存在,则提升对应的nodehead
    • 如果不存在则需要插入新的key,这个时候又要做判断
      • 容量未满,直接加入到头部
      • 容量已满,删除尾部节点之后再插入到头部
class LRUCache:
    ...
    def get(self, key: int) -> int:
        if key not in self.cache:
            return -1
        # 如果 key 存在,先通过哈希表定位,再移到头部
        node = self.cache[key]
        self.moveToHead(node)
        return node.value

    def put(self, key: int, value: int) -> None:
        if key not in self.cache:
            # 如果 key 不存在,创建一个新的节点
            node = DLinkedNode(key, value)
            # 添加进哈希表
            self.cache[key] = node
            # 添加至双向链表的头部
            self.addToHead(node)
            self.size += 1
            if self.size > self.capacity:
                # 如果超出容量,删除双向链表的尾部节点
                removed = self.removeTail()
                # 删除哈希表中对应的项
                self.cache.pop(removed.key)
                self.size -= 1
        else:
            # 如果 key 存在,先通过哈希表定位,再修改 value,并移到头部
            node = self.cache[key]
            node.value = value
            self.moveToHead(node)
    ...

之后我们再根据需要来写一下类初始化方法

class LRUCache:
    def __init__(self, capacity: int):
        self.cache = dict()
        # 使用伪头部和伪尾部节点  
        self.head = DLinkedNode()
        self.tail = DLinkedNode()
        self.head.next = self.tail
        self.tail.prev = self.head
        self.capacity = capacity
        self.size = 0

好的,我们把框架搭建起来之后,这个类的功能也就大致实现了,学废了嘛:)

装饰器@lru_cache介绍

啰啰嗦嗦说了这么多,我们看看在python中提供了自己的缓存工具functools.lru_cache(),这个装饰器实际上就是替我们实现lru的功能我们需要的时候直接用装饰器加载即可

@functools.lru_cache(user_function)
@functools.lru_cache(maxsize=128, typed=False)

maxsize保存最近多少个调用的结果,最好设置为 2 的倍数,默认为 128。如果设置为 None 的话就相当于是maxsize为正无穷了。还有一个参数是 type,如果 type 设置为 true,即把不同参数类型得到的结果分开保存,如 f(3) 和 f(3.0) 会被区分开。

源码分析

看看 Python 内部是怎么实现 lru_cache 的。写作时 Python 最新发行版是 3.9,所以这里使用的是Python 3.9的源码,并且保留了源码中的注释。

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {\'maxsize\': maxsize, \'typed\': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            \'Expected first argument to be an integer, a callable, or None\')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {\'maxsize\': maxsize, \'typed\': typed}
        return update_wrapper(wrapper, user_function)

    return decorating_function

这段代码中有如下几个关键点:

  • 关键字参数
    • maxsize表示缓存容量,如果为None表示容量不设限, typed表示是否区分参数类型,注释中也给出了解释,如果typed == True,那么f(3)f(3.0)会被认为是不同的函数调用。
  • 第 507 行的条件分支
    • 如果lru_cache的第一个参数是可调用的,直接返回wrapper,也就是把lru_cache当做不带参数的装饰器,这是 Python 3.8 才有的特性,也就是说在 Python 3.8 及之后的版本中我们可以用下面的方式使用lru_cache,可能是为了防止程序员在使用lru_cache的时候忘记加括号。【一会我们来实践一下】

lru_cache的具体逻辑是在_lru_cache_wrapper函数中实现的

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren\'t threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we\'re
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

函数在 530-537 行定义了一些关键变量:

  • hitsmisses分别表示缓存命中和没有命中的次数
  • root双向循环链表的头结点,每个节点保存前向指针、后向指针、keykey对应的result,其中key_make_key函数根据参数结算出来的字符串,result为被修饰的函数在给定的参数下返回的结果,也就是我们自己设计的时候keyvalue。注意,root是不保存数据keyresult的。
  • cache是真正保存缓存数据的地方,类型为dictcache中的key也是_make_key函数根据参数结算出来的字符串,value保存的是key对应的双向循环链表中的节点。

接下来根据maxsize不同,定义不同的wrapper

  • maxsize == 0,其实也就是没有缓存,那么每次函数调用都不会命中,并且没有对命中的次数misses加 1。
  • maxsize is None,不限制缓存大小,如果函数调用不命中,将没有命中次数misses加 1,否则将命中次数hits加 1。
  • 限制缓存的大小,那么需要根据 LRU 算法来更新cache,也就是 565~620 行的代码。
  • 如果缓存命中key,那么将命中节点移到双向循环链表的结尾,并且返回结果(571~581 行)这里通过字典加双向循环链表的组合数据结构,实现了用O(1)的时间复杂度删除给定的节点。
  • 如果没有命中,并且缓存满了,那么需要将最久没有使用的节点(root 的下一个节点)删除,并且将新的节点添加到链表结尾。在实现中有一个优化,直接将当前的rootkeyresult 替换成新的值,将root 的下一个节点置为新的root,这样得到的双向循环链表结构跟删除root的下一个节点并且将新节点加到链表结尾是一样的,但是避免了删除和添加节点的操作(591~611 行)
  • 如果没有命中,并且缓存没满,那么直接将新节点添加到双向循环链表的结尾(root[PREV])(613~619 行)

性能测试

我们以斐波拉契数的计算为例,来感受一下使用缓存和不使用之间的区别,注意比较执行时间的不同

from time import time

def factorial(n):
    print(f"计算 {n} 的阶乘")
    return 1 if n <= 1 else n * factorial(n - 1)

start = time()
a = factorial(10)
print(f\'10! = {a}\')
b = factorial(8)
print(f\'8! = {b}\')
end = time()
print("耗时:", end - start, \'secs\')

from functools import lru_cache
from time import time

@lru_cache()
def factorial(n):
    print(f"计算 {n} 的阶乘")
    return 1 if n <= 1 else n * factorial(n - 1)

start = time()
a = factorial(10)
print(f\'10! = {a}\')
b = factorial(8)
print(f\'8! = {b}\')
end = time()
print("耗时:", end - start, \'secs\')

如果我们计算11的阶乘呢?这个时候由于在缓存中查找不到key=11的值,因此仅需要计算11的阶乘

from functools import lru_cache
from time import time

@lru_cache()
def factorial(n):
    print(f"计算 {n} 的阶乘")
    return 1 if n <= 1 else n * factorial(n - 1)

start = time()
a = factorial(10)
print(f\'10! = {a}\')
b = factorial(8)
print(f\'8! = {b}\')
end = time()
print("耗时:", end - start, \'secs\')
b = factorial(11)
print(f\'11! = {b}\')

我们可以看到由于缓存的引入,我们在计算的过程中遇到曾经得出的结果就可以通过这种方式对过往结果进行服用,从而提升了计算效率。

@lru_cache的使用

实践出真知,我们先在算法练习中对其进行相关的性能测试,以70. 爬楼梯这道题为例,我们来体验一下展示使用装饰器@lru_cache的效果

class Solution:
    def climbStairs(self, n: int) -> int:
        if n == 1:
            return 1
        elif n == 2:
            return 2
        else:
            return self.climbStairs(n-1) + self.climbStairs(n-2)

加入装饰器@lru_cache

class Solution:
    @lru_cache
    def climbStairs(self, n: int) -> int:
        if n == 1:
            return 1
        elif n == 2:
            return 2
        else:
            return self.climbStairs(n-1) + self.climbStairs(n-2)

参考资料:

  1. LRU算法
  2. https://www.cnblogs.com/zikcheng/p/14322577.html

以上是关于lru_cache分析的主要内容,如果未能解决你的问题,请参考以下文章

Lru_cache(来自 functools)如何工作?

使用缓存方式优化递归函数与lru_cache

从内部函数禁用`functools.lru_cache`

如何用 `functools.lru_cache` 正确装饰`classmethod`?

functools.lru_cache的实现

AttributeError:“模块”对象没有属性“lru_cache”