如何实现 numba jitted 优先级队列?
Posted
技术标签:
【中文标题】如何实现 numba jitted 优先级队列?【英文标题】:How can I implement a numba jitted priority queue? 【发布时间】:2021-07-30 05:26:06 【问题描述】:我没能实现一个 numba jitted 优先级队列。
下面这个类大量借鉴了 Python 官方文档(heapq 模块)中的示例,我对它还算满意。
import itertools
import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop
class PurePythonPriorityQueue:
    """Min-priority queue with updatable priorities (heapq lazy-deletion recipe).

    Each heap entry is a mutable list ``[priority, count, item]``. ``count`` is a
    unique, increasing sequence number, so entries with equal priority are
    popped in insertion order and ``item`` itself is never compared.
    Updating or removing an item marks its old entry as removed instead of
    searching the heap for it; ``pop`` discards marked entries.
    """

    def __init__(self):
        self.pq = []  # list of entries arranged in a heap
        self.entry_finder = {}  # mapping of items to their live entries
        # Unique placeholder for a removed item. A fresh object() can never
        # collide with a real item, and identity comparison against it is
        # always reliable (``is`` on ints depends on CPython's small-int cache).
        self.REMOVED = object()
        self.counter = itertools.count()  # unique sequence count

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item."""
        if item in self.entry_finder:
            self.remove_item(item)
        count = next(self.counter)
        entry = [priority, count, item]
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as removed. Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        # The entry stays in the heap; pop() skips it once it surfaces.
        entry[-1] = self.REMOVED

    def pop(self):
        """Remove and return the lowest-priority item. Raise KeyError if empty."""
        while self.pq:
            _priority, _count, item = heappop(self.pq)
            if item is not self.REMOVED:
                del self.entry_finder[item]
                return item
        raise KeyError("pop from an empty priority queue")
现在我想从一个 numba jitted 函数中调用它来做繁重的数值工作,所以我试着把它变成一个 numba jitclass。由于在 vanilla python 实现中条目是异构列表,我想我也应该实现其他 jitclasses。但是,我收到了Failed in nopython mode pipeline (step: nopython frontend)
(下面的完整跟踪)。
这是我的尝试:
@jitclass
class Item:
    """A 2-D integer index ``(i, j)`` used as a queue item and typed-Dict key."""

    i: int
    j: int

    def __init__(self, i, j):
        self.i = i
        self.j = j

    # A numba typed Dict needs ``==`` (and hashing) on its key type. Without
    # them, typing ``nb.typed.Dict.empty(Item(...), ...)`` fails with
    # "No implementation of function ... eq(instance.jitclass.Item, ...)".
    # NOTE(review): relies on jitclass dunder-method support, which only newer
    # numba releases provide. Confirm against the numba version in use.
    def __eq__(self, other):
        return self.i == other.i and self.j == other.j

    def __hash__(self):
        return hash((self.i, self.j))
@jitclass
class Entry:
    """Heap entry: priority, tie-breaking counter, item and a lazy-removal flag."""

    priority: float
    count: int
    item: Item
    removed: bool

    def __init__(self, p: float, c: int, i: Item):
        self.priority = p
        self.count = c
        self.item = i
        self.removed = False

    # heappush/heappop compare entries with "<". Order by priority, then by the
    # unique counter. This mirrors the [priority, count, item] lists of the
    # pure-Python version, so ``item`` is never compared.
    # NOTE(review): needs numba's jitclass dunder-method support; confirm.
    def __lt__(self, other):
        if self.priority != other.priority:
            return self.priority < other.priority
        return self.count < other.count
@jitclass
class PriorityQueue:
    """Numba jitclass port of PurePythonPriorityQueue using Item/Entry jitclasses.

    Lazy deletion: updating or removing an item sets ``removed`` on its old
    Entry, and ``pop`` discards such entries when they reach the top of the heap.
    """

    pq: List[Entry]  # heap of entries
    entry_finder: Dict[Item, Entry]  # item -> its live entry
    counter: int  # last sequence number handed out

    def __init__(self):
        # The example values only fix the element types of the empty containers.
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0.0, 0, Item(0, 0)))
        self.counter = 0

    def put(self, item: Item, priority: float = 0.0):
        """Add a new item or update the priority of an existing item."""
        if item in self.entry_finder:
            self.remove_item(item)
        self.counter += 1
        entry = Entry(priority, self.counter, item)
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Item):
        """Mark an existing item as removed. Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry.removed = True

    def pop(self):
        """Remove and return the lowest-priority live item. Raise KeyError if empty."""
        while self.pq:
            # Pop exactly one Entry per iteration. Entry is a jitclass, not a
            # tuple, so its fields are read as attributes instead of unpacked.
            entry = heappop(self.pq)
            if not entry.removed:
                del self.entry_finder[entry.item]
                return entry.item
        raise KeyError("pop from an empty priority queue")
if __name__ == "__main__":
    # Pure-Python reference queue keyed by plain (i, j) tuples.
    reference_queue = PurePythonPriorityQueue()
    for key, prio in (((4, 5), 5.4), ((5, 6), 1.0)):
        reference_queue.put(key, prio)
    print(reference_queue.pop())  # printed (5, 6) in the asker's run

    # Same sequence against the numba jitclass queue. In the asker's run,
    # constructing PriorityQueue() raised the TypingError in the traceback below.
    jitted_queue = PriorityQueue()
    for (i, j), prio in (((4, 5), 5.4), ((5, 6), 1.0)):
        jitted_queue.put(Item(i, j), prio)
    print(jitted_queue.pop())
这种类型的数据结构可以用 numba 实现吗?我当前的实现有什么问题?
完整的跟踪:
(5, 6)
Traceback (most recent call last):
File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
queue2 = PriorityQueue() # Nope
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
return cls._ctor(*bind.args[1:], **bind.kwargs)
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
error_rewrite(e, 'typing')
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:
>>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:
>>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
No implementation of function Function(<built-in function eq>) found for signature:
>>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)
There are 30 candidate implementations:
- Of which 28 did not match due to:
Overload of function 'eq': File: <numerous>: Line N/A.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'eq': File: unknown: Line unknown.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match for registered cases:
* (bool, bool) -> bool
* (int8, int8) -> bool
* (int16, int16) -> bool
* (int32, int32) -> bool
* (int64, int64) -> bool
* (uint8, uint8) -> bool
* (uint16, uint16) -> bool
* (uint32, uint32) -> bool
* (uint64, uint64) -> bool
* (float32, float32) -> bool
* (float64, float64) -> bool
* (complex64, complex64) -> bool
* (complex128, complex128) -> bool
During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229
During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)
File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
def impl(cls, key_type, value_type):
return dictobject.new_dict(key_type, value_type)
^
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)
File "priorityqueue.py", line 72:
def __init__(self):
<source elided>
self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
^
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)
File "<string>", line 3:
<source missing, REPL/exec in use?>
Process finished with exit code 1
【问题讨论】:
【参考方案1】:由于 numba 中的几个问题,目前这还无法实现,但如果我理解正确,这些问题应该会在下一个版本 (0.55) 中修复。作为临时解决方法,我自行编译了 llvmlite 0.38.0dev0 和 numba 的主分支,从而让它能够运行。我不使用 conda,不过据说用 conda 获取 llvmlite 和 numba 的预发布版本会更容易。
这是我的实现:
from heapq import heappush, heappop
from typing import List, Tuple, Dict, Any
import numba as nb
import numpy as np
from numba.experimental import jitclass
class UpdatablePriorityQueueEntry:
    """Heap entry pairing an arbitrary payload with a float priority.

    Entries compare by ``priority`` only, so the payload never has to be
    orderable.
    """

    def __init__(self, p: float, i: Any):
        self.priority, self.item = p, i

    def __lt__(self, other: "UpdatablePriorityQueueEntry"):
        # heapq only ever needs "<". Equal priorities are left to the heap.
        return self.priority < other.priority
class UpdatablePriorityQueue:
    """Min-priority queue whose items are re-prioritised by calling ``put`` again.

    Updates are lazy. ``put`` always pushes a new heap entry and records the
    item's latest priority in ``entries_priority``. ``pop`` discards heap entries
    whose priority no longer matches that record.
    """

    def __init__(self):
        self.pq = []  # heap of UpdatablePriorityQueueEntry, ordered by priority
        self.entries_priority = {}  # item -> latest priority (np.inf once popped)

    def put(self, item: Any, priority: float = 0.0):
        """Add *item*, or update its priority if it is already queued."""
        entry = UpdatablePriorityQueueEntry(priority, item)
        self.entries_priority[item] = priority
        heappush(self.pq, entry)

    def pop(self) -> Any:
        """Remove and return the item with the lowest current priority.

        Raises:
            KeyError: if no live entry remains.
        """
        while self.pq:
            entry = heappop(self.pq)
            # Stale entries (superseded by a later put, or already popped) don't match.
            if entry.priority == self.entries_priority[entry.item]:
                # Mark as popped instead of deleting, so lookups for remaining
                # stale copies still succeed and never match.
                # NOTE: an item queued with priority np.inf is indistinguishable
                # from a popped one.
                self.entries_priority[entry.item] = np.inf
                return entry.item
        raise KeyError("pop from an empty priority queue")

    def clear(self):
        """Drop every heap entry and forget all recorded priorities."""
        self.pq.clear()
        self.entries_priority.clear()
@jitclass
class PriorityQueueEntry2d(UpdatablePriorityQueueEntry):
    """Numba-compiled heap entry whose item is a 2-D integer index ``(i, j)``.

    Ordering comes from the inherited ``__lt__``. NOTE(review): this assumes
    numba's jitclass picks up inherited methods; confirm on the numba version in use.
    """

    priority: float
    item: Tuple[int, int]

    def __init__(self, p: float, i: Tuple[int, int]):
        self.priority = p
        self.item = i


# The jitclass queue below refers to this class as ``PriorityQueueEntry2d``.
# Keep the original name as an alias for any existing callers.
PriorityQueueEntry = PriorityQueueEntry2d
@jitclass
class UpdatablePriorityQueue(UpdatablePriorityQueue):
    """Numba-compiled UpdatablePriorityQueue for ``(i, j)`` integer-tuple items.

    ``pop`` and ``clear`` are not redefined here and come from the pure-Python
    base class. Only the typed containers and the entry type differ. This
    definition rebinds the name ``UpdatablePriorityQueue``, so afterwards the
    pure-Python class is reachable only as this class's base.
    """

    pq: List[PriorityQueueEntry2d]
    entries_priority: Dict[Tuple[int, int], float]

    def __init__(self):
        # Example values only fix the element types; both containers start empty.
        proto_entry = PriorityQueueEntry2d(0.0, (0, 0))
        self.pq = nb.typed.List.empty_list(proto_entry)
        self.entries_priority = nb.typed.Dict.empty((0, 0), 0.0)

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Record *item*'s latest priority and push a fresh heap entry for it."""
        self.entries_priority[item] = priority
        heappush(self.pq, PriorityQueueEntry2d(priority, item))
【讨论】:
【参考方案2】:我遇到了与自定义类 Entry 相关的类似问题。基本上 Numba 无法使用 __lt__(self, other)
来比较条目,并给了我一个 No implementation of function Function(< built-in function lt >)
错误。
所以我想出了以下内容。它适用于 Ubuntu 18.04 上 Python 3.8 上的 Numba 0.55.1。诀窍是避免使用任何自定义类对象作为优先队列项的一部分,以避免上述错误。
from typing import List, Dict, Tuple
from heapq import heappush, heappop
import numba as nb
from numba.experimental import jitclass
# Prototype heap entry used only to derive the numba tuple type below.
entry_def = (
    0.0,  # priority
    0,  # insertion counter (breaks ties between equal priorities)
    (0, 0),  # item
    nb.typed.List([False]),  # removed flag, boxed in a List so it can be flipped in place
)
entry_type = nb.typeof(entry_def)  # numba type used in the jitclass annotations
@jitclass
class PriorityQueue:
    """Numba jitclass min-priority queue whose items are ``(i, j)`` int tuples.

    Heap entries are plain tuples ``(priority, counter, item, removed)`` of numba
    type ``entry_type``. Using tuples instead of custom jitclass entries avoids
    the missing ``<`` support for them. ``removed`` is a one-element typed List,
    so the flag can be flipped in place inside the otherwise immutable tuple
    (lazy deletion).
    """

    # These annotations form the jitclass spec: they tell numba each member's type.
    pq: List[entry_type]  # heap of entry tuples
    entry_finder: Dict[Tuple[int, int], entry_type]  # item -> its live entry
    counter: int  # incremented on every put; second tuple field, breaks priority ties
    entry: entry_type  # scratch slot used only by remove_item
    # NOTE(review): ``entry`` is not initialised in __init__; it is first assigned in remove_item.

    def __init__(self):
        # Must declare types here see https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
        # The prototype values only fix the container element types. They are
        # spelled out inline because the global entry_def could not be reused here.
        self.pq = nb.typed.List.empty_list((0.0, 0, (0,0), nb.typed.List([False])))
        self.entry_finder = nb.typed.Dict.empty( (0, 0), (0.0, 0, (0,0), nb.typed.List([False])))
        self.counter = 0

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item.

        Any existing entry for *item* is flagged as removed rather than deleted
        from the heap. A fresh entry with the new priority is then pushed.
        """
        if item in self.entry_finder:
            # Mark duplicate item for deletion
            self.remove_item(item)
        self.counter += 1
        entry = (priority, self.counter, item, nb.typed.List([False]))
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Flag an existing item's entry as removed (True). Raise KeyError if not found."""
        self.entry = self.entry_finder.pop(item)
        # The same flag List object is also inside the tuple stored in the heap,
        # so flipping it here marks the heap entry too.
        self.entry[3][0] = True

    def pop(self):
        """Remove the lowest-priority live entry and return ``(priority, item)``.

        Entries flagged as removed are discarded on the way. Raise KeyError if
        no live entry remains.
        """
        while self.pq:
            priority, count, item, removed = heappop(self.pq)
            if not removed[0]:
                del self.entry_finder[item]
                return priority, item
        raise KeyError("pop from an empty priority queue")
首先定义一个名为entry_def
的全局变量,它将作为我们优先级队列pq
中的条目。 “已删除”的标记现在被替换为 numba.typed.List([False])
,用于在优先级更改时标记需要删除的旧条目(延迟删除)。烦人的是,必须手动写出pq
和entry_finder
的定义;我无法重用 entry_def
变量。
我可以确认PriorityQueue
的工作方式如下:
q = PriorityQueue()
q.put((1,1), 5.0)
q.put((1,1), 4.0)
q.put((1,1), 3.0)
q.put((1,1), 6.0)
print(q.pq)
>> [(3.0, 3, (1, 1), ListType[bool]([True])), (5.0, 1, (1, 1), ListType[bool]([True])), (4.0, 2, (1, 1), ListType[bool]([True])), (6.0, 4, (1, 1), ListType[bool]([False]))]
print(q.pop())
>> (6.0, (1, 1))
print(len(q.entry_finder))
>> 0
希望有人会发现这很有用或可以提供更好的选择。
【讨论】:
以上是关于如何实现 numba jitted 优先级队列?的主要内容,如果未能解决你的问题,请参考以下文章