尝试使用 numba 循环 numpy 数组时出错
Posted
技术标签:
【中文标题】尝试使用 numba 循环 numpy 数组时出错【英文标题】:Getting error when trying to loop numpy arrays with numba 【发布时间】:2022-01-01 21:21:01 【问题描述】:你好! 我正在尝试使用 numba 来加快 numpy 数组的循环。我的代码:
@nb.njit
def _complicated(heat_list_, raw_array_, x_min_, x_step_, y_min_, y_step_):
array = np.array(raw_array_)
z_val = array[2, :]
x_val = array[0, :]
y_val = array[1, :]
for i in range(len(z_val)):
heat_list_[np.round_((y_val[i] - y_min_) / y_step_)][np.round_((x_val[i] - x_min_) / x_step_)] = z_val[i]
return heat_list_
heat_list_
, raw_list_
是 numpy.ndarraysx_min_, x_step_, y_min_, y_step_
是浮点数
当我运行这个函数时,出现以下错误:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_28508/1500348391.py in <module>
----> 1 _complicated(data1.heat_list, data1.raw_array, data1.x_min, data1.x_step, data1.y_min, data1.y_step)
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
418 e.patch_message(msg)
419
--> 420 error_rewrite(e, 'typing')
421 except errors.UnsupportedError as e:
422 # Something unsupported is present in the user code, add help info
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in error_rewrite(e, issue_type)
359 raise e
360 else:
--> 361 raise e.with_traceback(None)
362
363 argtypes = []
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:
>>> array(array(float64, 2d, C))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'array': File: numba\core\typing\npydecl.py: Line 489.
With argument(s): '(array(float64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence
raised from C:\Users\nokol\miniconda3\envs\pyquac\lib\site-packages\numba\core\typing\npydecl.py:457
During: resolving callee type: Function(<built-in function array>)
During: typing of call at C:\Users\nokol\AppData\Local\Temp/ipykernel_28508/1750607535.py (14)
File "..\..\..\..\AppData\Local\Temp\ipykernel_28508\1750607535.py", line 14:
<source missing, REPL/exec in use?>
我不明白为什么会这样,因为函数本身很简单,而且其中的数据类型是 numba 支持的。 如果有人能帮我弄清楚,我会很高兴。PS numba 版本 == 0.53.1python 版本 == 3.9.7
第 2 版 我已经稍微更改了我的代码,但仍然出现错误:
@nb.generated_jit(nopython=True)
def _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_):
for i in range(len(raw_array_y)):
heat_list_[np.round_((raw_array_y[i] - y_min_) / y_step_),
np.round_((raw_array_x[i] - x_min_) / x_step_)] = raw_array_z[i]
return heat_list_
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_17704/3567269810.py in <module>
----> 1 data1.njit_result()
~\AppData\Local\Temp/ipykernel_17704/1768352675.py in njit_result(self)
336
337 if len(self.x_raw) >= 2:
--> 338 return _complicated(self.heat_list, x_val, y_val, z_val,
339 self.x_min, self.x_step, self.y_min, self.y_step, len_of_z)
340 else:
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
499 e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
500 # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 501 raise e
502 finally:
503 self._types_active_call = []
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
432 return_val = None
433 try:
--> 434 return_val = self.compile(tuple(argtypes))
435 except errors.ForceLiteralArg as e:
436 # Received request for compiler re-entry with the list of arguments
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in compile(self, sig)
977 with ev.trigger_event("numba:compile", data=ev_details):
978 try:
--> 979 cres = self._compiler.compile(args, return_type)
980 except errors.ForceLiteralArg as e:
981 def folded(args, kws):
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in compile(self, args, return_type)
139
140 def compile(self, args, return_type):
--> 141 status, retval = self._compile_cached(args, return_type)
142 if status:
143 return retval
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_cached(self, args, return_type)
153
154 try:
--> 155 retval = self._compile_core(args, return_type)
156 except errors.TypingError as e:
157 self._failed_cache[key] = e
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_core(self, args, return_type)
165 flags = self._customize_flags(flags)
166
--> 167 impl = self._get_implementation(args, )
168 cres = compiler.compile_extra(self.targetdescr.typing_context,
169 self.targetdescr.target_context,
~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _get_implementation(self, args, kws)
201
202 def _get_implementation(self, args, kws):
--> 203 impl = self.py_func(*args, **kws)
204 # Check the generating function and implementation signatures are
205 # compatible, otherwise compiling would fail later.
~\AppData\Local\Temp/ipykernel_17704/1768352675.py in _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_, len_of_z)
12 def _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_,
13 len_of_z):
---> 14 for i in range(len(raw_array_y)):
15 heat_list_[np.round_((raw_array_y[i] - y_min_) / y_step_),
16 np.round_((raw_array_x[i] - x_min_) / x_step_)] = raw_array_z[i]
TypeError: object of type 'Array' has no len()
【问题讨论】:
【参考方案1】:与 Numpy 不同,Numba 不支持使用已经是 Numpy 数组的参数调用 np.array
。这似乎是 Numpy 的一个错误或不受支持的功能(如果是这样,可以在 bug tracker 上报告)。在您的情况下,这并不重要,因为不需要此调用,因为输入已经是一个 Numpy 数组。如果情况并非总是如此,那么您可能需要使用 generated-jit
装饰器。但是,一个更简单的解决方案是确保输入参数是一个 Numpy 数组,否则,在调用 Numba 函数之前进行转换。
【讨论】:
我尝试了你所说的一切,但现在它给出了一个相当奇怪的错误,指的是...range(len(z_val))
。它说TypeError: 'Integer' object cannot be interpreted as an integer
还有TypeError: object of type 'Array' has no len()
在您调用该函数时,输入类型似乎很奇怪。我无法使用 float64[:,:,:]
类型的 raw_list_
重现此特定问题。您能否提供准确的输入类型(即数组的维度、项目的类型)?除此之外,建议您直接在装饰器中为 Numba 指定类型,以便更早地捕获输入错误。以上是关于尝试使用 numba 循环 numpy 数组时出错的主要内容,如果未能解决你的问题,请参考以下文章