How can I implement a numba-jitted priority queue?

【Posted】: 2021-07-30 05:26:06

【Question】:

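I haven't managed to implement a numba-jitted priority queue.

I borrowed heavily from the Python docs (the priority queue recipe in the heapq module documentation), and I'm fairly happy with this pure-Python class: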

import itertools

import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop


class PurePythonPriorityQueue:
    def __init__(self):
        self.pq = [] # list of entries arranged in a heap
        self.entry_finder = {}  # mapping of items to entries
        self.REMOVED = -1 # placeholder for a removed item
        self.counter = itertools.count() # unique sequence count

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        count = next(self.counter)
        entry = [priority, count, item]
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry[-1] = self.REMOVED

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            priority, count, item = heappop(self.pq)
            if item is not self.REMOVED:
                del self.entry_finder[item]
                return item
        raise KeyError("pop from an empty priority queue")
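The count field in each entry deserves a word. heapq compares entries element by element, so when two priorities tie, the unique counter decides the order and the items themselves are never compared. Below is a small stdlib-only sketch of that tie-breaking. Unorderable is a hypothetical stand-in for any item type that defines no ordering.

```python
import heapq
import itertools


class Unorderable:
    """Hypothetical item type that defines no ordering (no __lt__)."""

    def __init__(self, name):
        self.name = name


counter = itertools.count()
heap = []
for priority, name in [(1.0, "a"), (1.0, "b"), (0.5, "c")]:
    # [priority, count, item]: ties on priority are broken by the unique count,
    # so Python never has to evaluate Unorderable < Unorderable.
    heapq.heappush(heap, [priority, next(counter), Unorderable(name)])

order = [heapq.heappop(heap)[2].name for _ in range(len(heap))]
print(order)  # ['c', 'a', 'b']
```

If you drop the counter and push [priority, item] instead, heappush raises TypeError as soon as two priorities tie.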

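Now I want to call it from a numba-jitted function that does heavy numerical work, so I tried to turn it into a numba jitclass. Since the entries in the vanilla Python implementation are heterogeneous lists, I figured I should implement the entries and items as jitclasses too. However, I get Failed in nopython mode pipeline (step: nopython frontend) (full traceback below).

Here is my attempt: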

@jitclass
class Item:
    i: int
    j: int

    def __init__(self, i, j):
        self.i = i
        self.j = j


@jitclass
class Entry:
    priority: float
    count: int
    item: Item
    removed: bool

    def __init__(self, p: float, c: int, i: Item):
        self.priority = p
        self.count = c
        self.item = i
        self.removed = False


@jitclass
class PriorityQueue:
    pq: List[Entry]
    entry_finder: Dict[Item, Entry]
    counter: int

    def __init__(self):
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        self.counter = 0

    def put(self, item: Item, priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        self.counter += 1
        entry = Entry(priority, self.counter, item)
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Item):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry.removed = True

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            entry = heappop(self.pq)
            if not entry.removed:
                del self.entry_finder[entry.item]
                return entry.item
        raise KeyError("pop from an empty priority queue")


if __name__ == "__main__":
    queue1 = PurePythonPriorityQueue()
    queue1.put((4, 5), 5.4)
    queue1.put((5, 6), 1.0)
    print(queue1.pop())  # Yay this works!

    queue2 = PriorityQueue()  # Nope
    queue2.put(Item(4, 5), 5.4)
    queue2.put(Item(5, 6), 1.0)
    print(queue2.pop())

Can this kind of data structure be implemented with numba? What is wrong with my current implementation?

Full traceback:

(5, 6)
Traceback (most recent call last):
  File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
    queue2 = PriorityQueue()  # Nope
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
    return cls._ctor(*bind.args[1:], **bind.kwargs)
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:

 >>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
    With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:

    >>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

   There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
           With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
          Rejected as the implementation raised a specific error:
            TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
          No implementation of function Function(<built-in function eq>) found for signature:

           >>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)

          There are 30 candidate implementations:
                - Of which 28 did not match due to:
                Overload of function 'eq': File: <numerous>: Line N/A.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match.
                - Of which 2 did not match due to:
                Operator Overload in function 'eq': File: unknown: Line unknown.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match for registered cases:
                  * (bool, bool) -> bool
                  * (int8, int8) -> bool
                  * (int16, int16) -> bool
                  * (int32, int32) -> bool
                  * (int64, int64) -> bool
                  * (uint8, uint8) -> bool
                  * (uint16, uint16) -> bool
                  * (uint32, uint32) -> bool
                  * (uint64, uint64) -> bool
                  * (float32, float32) -> bool
                  * (float64, float64) -> bool
                  * (complex64, complex64) -> bool
                  * (complex128, complex128) -> bool

          During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
     raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229

   During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
   During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)


   File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
       def impl(cls, key_type, value_type):
           return dictobject.new_dict(key_type, value_type)
           ^

  raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)


File "priorityqueue.py", line 72:
    def __init__(self):
        <source elided>
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        ^

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>


Process finished with exit code 1
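The innermost error in the nested traceback is No implementation of function ... eq for the signature (Item, Item). A typed.Dict has to compare its keys with == (and hash them), and the Item jitclass defines no equality. Even in plain Python, a class without __eq__/__hash__ compares and hashes by identity, so two Item(4, 5) objects would be two different keys. Tuples, by contrast, compare and hash by value. Here is a minimal stdlib illustration; PlainItem is a hypothetical pure-Python analogue of the jitclass, not numba code.

```python
class PlainItem:
    """Hypothetical pure-Python analogue of the Item jitclass: no __eq__/__hash__."""

    def __init__(self, i, j):
        self.i = i
        self.j = j


finder = {PlainItem(4, 5): "entry"}
# Default object equality and hashing are identity-based, so an equal-looking key misses.
print(PlainItem(4, 5) in finder)  # False

finder_by_tuple = {(4, 5): "entry"}
# Tuples hash and compare by value, which is what a lookup table of items needs.
print((4, 5) in finder_by_tuple)  # True
```

This is consistent with both answers keying their lookup tables on plain (i, j) tuples.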


【Answer 1】:

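This isn't possible yet because of several issues in numba, but if I understand correctly they should be fixed in the next release (0.55). As a workaround for now, I got it working by compiling llvmlite 0.38.0dev0 and the main branch of numba from source. I don't use conda, but apparently getting pre-release builds of llvmlite and numba is easier that way.

Here is my implementation: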

from heapq import heappush, heappop
from typing import List, Tuple, Dict, Any

import numba as nb
import numpy as np
from numba.experimental import jitclass


class UpdatablePriorityQueueEntry:
    def __init__(self, p: float, i: Any):
        self.priority = p
        self.item = i

    def __lt__(self, other: "UpdatablePriorityQueueEntry"):
        return self.priority < other.priority


class UpdatablePriorityQueue:
    def __init__(self):
        self.pq = []
        self.entries_priority = {}

    def put(self, item: Any, priority: float = 0.0):
        entry = UpdatablePriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)

    def pop(self) -> Any:
        while self.pq:
            entry = heappop(self.pq)
            if entry.priority == self.entries_priority[entry.item]:
                self.entries_priority[entry.item] = np.inf
                return entry.item
        raise KeyError("pop from an empty priority queue")

    def clear(self):
        self.pq.clear()
        self.entries_priority.clear()


@jitclass
class PriorityQueueEntry2d(UpdatablePriorityQueueEntry):
    priority: float
    item: Tuple[int, int]

    def __init__(self, p: float, i: Tuple[int, int]):
        self.priority = p
        self.item = i


@jitclass
class UpdatablePriorityQueue2d(UpdatablePriorityQueue):
    pq: List[PriorityQueueEntry2d]
    entries_priority: Dict[Tuple[int, int], float]

    def __init__(self):
        self.pq = nb.typed.List.empty_list(PriorityQueueEntry2d(0.0, (0, 0)))
        self.entries_priority = nb.typed.Dict.empty((0, 0), 0.0)

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        entry = PriorityQueueEntry2d(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)
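The idea behind this version is that it needs no removed flag. entries_priority remembers each item's latest priority, and pop discards any heap entry whose priority no longer matches, that is, a stale entry left behind by an update. Below is a stdlib-only sketch of the same lazy-deletion rule. It uses plain (priority, item) tuples instead of entry classes and math.inf where the answer uses np.inf. The functions put and pop are hypothetical free-function versions, not the answer's methods.

```python
import heapq
import math


def put(heap, current, item, priority):
    # Record the latest priority; older heap entries for this item become stale.
    current[item] = priority
    heapq.heappush(heap, (priority, item))


def pop(heap, current):
    while heap:
        priority, item = heapq.heappop(heap)
        # Only the entry whose priority matches the recorded one is live.
        if priority == current.get(item, math.inf):
            current[item] = math.inf  # mark as consumed, like the answer's np.inf
            return item
    raise KeyError("pop from an empty priority queue")


heap, current = [], {}
put(heap, current, (4, 5), 5.4)
put(heap, current, (5, 6), 1.0)
put(heap, current, (4, 5), 0.5)  # update: the (5.4, (4, 5)) entry is now stale
print(pop(heap, current))  # (4, 5)
print(pop(heap, current))  # (5, 6)
```

A third pop would skip the stale (5.4, (4, 5)) entry and raise KeyError, because the recorded priority for (4, 5) is now infinity.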


【Answer 2】:

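I ran into a similar problem with a custom Entry class. Basically, Numba couldn't compare entries using __lt__(self, other) and gave me a No implementation of function Function(<built-in function lt>) error.

So I came up with the following. It works with Numba 0.55.1 on Python 3.8 on Ubuntu 18.04. The trick is to avoid using any custom class objects as part of the priority queue entries, which sidesteps the error above.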

from typing import List, Dict, Tuple 
from heapq import heappush, heappop
import numba as nb
from numba.experimental import jitclass

# priority, counter, item, removed
entry_def = (0.0, 0, (0,0), nb.typed.List([False]))
entry_type = nb.typeof(entry_def)

@jitclass
class PriorityQueue:
    # The following helps numba infer type of variable
    pq: List[entry_type]
    entry_finder: Dict[Tuple[int, int], entry_type]
    counter: int
    entry: entry_type

    def __init__(self):
        # Must declare types here see https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
        self.pq = nb.typed.List.empty_list((0.0, 0, (0,0), nb.typed.List([False])))
        self.entry_finder = nb.typed.Dict.empty((0, 0), (0.0, 0, (0,0), nb.typed.List([False])))
        self.counter = 0

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            # Mark duplicate item for deletion
            self.remove_item(item)
    
        self.counter += 1
        entry = (priority, self.counter, item, nb.typed.List([False]))
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED via True.  Raise KeyError if not found."""
        self.entry = self.entry_finder.pop(item)
        self.entry[3][0] = True
    
    def pop(self):
        """Remove the lowest-priority item and return (priority, item). Raise KeyError if empty."""
        while self.pq:
            priority, count, item, removed = heappop(self.pq)
            if not removed[0]:
                del self.entry_finder[item]
                return priority, item
        raise KeyError("pop from an empty priority queue")

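First, define a global variable named entry_def, which serves as the template for the entries in the priority queue pq. The "removed" marker is replaced by numba.typed.List([False]). Tuples are immutable, but the one-element list inside the tuple can be flipped in place. This is how items are tracked for deletion when their priority key changes (lazy deletion). The annoying part is having to write out the definitions of pq and entry_finder by hand; I couldn't reuse the entry_def variable.

I can confirm that PriorityQueue works as follows: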
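The flag trick relies on shared references. pq and entry_finder hold the same tuple, and therefore the same list object. Setting removed[0] = True through entry_finder is visible to the entry still sitting in the heap. Below is a stdlib-only sketch of that mechanism. Plain Python lists stand in for numba.typed.List, and module-level put/pop stand in for the jitclass methods; both substitutions are assumptions made for illustration.

```python
import heapq
import itertools

pq = []
entry_finder = {}
counter = itertools.count(1)


def put(item, priority):
    if item in entry_finder:
        # The tuple is immutable, but its one-element list is shared with pq.
        entry_finder.pop(item)[3][0] = True
    entry = (priority, next(counter), item, [False])
    entry_finder[item] = entry
    heapq.heappush(pq, entry)


def pop():
    while pq:
        priority, count, item, removed = heapq.heappop(pq)
        if not removed[0]:
            del entry_finder[item]
            return priority, item
    raise KeyError("pop from an empty priority queue")


for p in (5.0, 4.0, 3.0, 6.0):
    put((1, 1), p)
print(pop())  # (6.0, (1, 1))
print(len(entry_finder))  # 0
```

The same sequence of put calls reproduces the output reported for the numba version: the three superseded entries are skipped and the last one (priority 6.0) is returned.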

    q = PriorityQueue()
    q.put((1,1), 5.0)
    q.put((1,1), 4.0)
    q.put((1,1), 3.0)
    q.put((1,1), 6.0)
    print(q.pq)
    >>  [(3.0, 3, (1, 1), ListType[bool]([True])), (5.0, 1, (1, 1), ListType[bool]([True])), (4.0, 2, (1, 1), ListType[bool]([True])), (6.0, 4, (1, 1), ListType[bool]([False]))]
    print(q.pop())
    >> (6.0, (1, 1))
    print(len(q.entry_finder))
    >> 0

I hope someone finds this useful or can offer a better alternative.
