尝试使用 numba 循环 numpy 数组时出错

Posted

技术标签:

【中文标题】尝试使用 numba 循环 numpy 数组时出错【英文标题】:Getting error when trying to loop numpy arrays with numba 【发布时间】:2022-01-01 21:21:01 【问题描述】:

你好! 我正在尝试使用 numba 来加快 numpy 数组的循环。我的代码:

@nb.njit
def _complicated(heat_list_, raw_array_, x_min_, x_step_, y_min_, y_step_):
        
        array = np.array(raw_array_)
        z_val = array[2, :]
        x_val = array[0, :]
        y_val = array[1, :]

        for i in range(len(z_val)):
            heat_list_[np.round_((y_val[i] - y_min_) / y_step_)][np.round_((x_val[i] - x_min_) / x_step_)] = z_val[i]

        return heat_list_

heat_list_, raw_list_ 是 numpy.ndarraysx_min_, x_step_, y_min_, y_step_ 是浮点数 当我运行这个函数时,出现以下错误:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_28508/1500348391.py in <module>
----> 1 _complicated(data1.heat_list, data1.raw_array, data1.x_min, data1.x_step, data1.y_min, data1.y_step)

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
    418                 e.patch_message(msg)
    419 
--> 420             error_rewrite(e, 'typing')
    421         except errors.UnsupportedError as e:
    422             # Something unsupported is present in the user code, add help info

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in error_rewrite(e, issue_type)
    359                 raise e
    360             else:
--> 361                 raise e.with_traceback(None)
    362 
    363         argtypes = []

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:
 
 >>> array(array(float64, 2d, C))
 
There are 2 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'array': File: numba\core\typing\npydecl.py: Line 489.
        With argument(s): '(array(float64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence
  raised from C:\Users\nokol\miniconda3\envs\pyquac\lib\site-packages\numba\core\typing\npydecl.py:457

During: resolving callee type: Function(<built-in function array>)
During: typing of call at C:\Users\nokol\AppData\Local\Temp/ipykernel_28508/1750607535.py (14)


File "..\..\..\..\AppData\Local\Temp\ipykernel_28508\1750607535.py", line 14:
<source missing, REPL/exec in use?>

我不明白为什么会这样,因为函数本身很简单,而且其中的数据类型是 numba 支持的。 如果有人能帮我弄清楚,我会很高兴。PS numba 版本 == 0.53.1python 版本 == 3.9.7

第 2 版 我已经稍微更改了我的代码,但仍然出现错误:

@nb.generated_jit(nopython=True)
def _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_):

    for i in range(len(raw_array_y)):
        heat_list_[np.round_((raw_array_y[i] - y_min_) / y_step_), 
                   np.round_((raw_array_x[i] - x_min_) / x_step_)] = raw_array_z[i]

    return heat_list_
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_17704/3567269810.py in <module>
----> 1 data1.njit_result()

~\AppData\Local\Temp/ipykernel_17704/1768352675.py in njit_result(self)
    336 
    337         if len(self.x_raw) >= 2:
--> 338             return _complicated(self.heat_list, x_val, y_val, z_val,
    339                                 self.x_min, self.x_step, self.y_min, self.y_step, len_of_z)
    340         else:

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
    499                     e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
    500             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 501             raise e
    502         finally:
    503             self._types_active_call = []

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
    432         return_val = None
    433         try:
--> 434             return_val = self.compile(tuple(argtypes))
    435         except errors.ForceLiteralArg as e:
    436             # Received request for compiler re-entry with the list of arguments

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in compile(self, sig)
    977                 with ev.trigger_event("numba:compile", data=ev_details):
    978                     try:
--> 979                         cres = self._compiler.compile(args, return_type)
    980                     except errors.ForceLiteralArg as e:
    981                         def folded(args, kws):

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in compile(self, args, return_type)
    139 
    140     def compile(self, args, return_type):
--> 141         status, retval = self._compile_cached(args, return_type)
    142         if status:
    143             return retval

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_cached(self, args, return_type)
    153 
    154         try:
--> 155             retval = self._compile_core(args, return_type)
    156         except errors.TypingError as e:
    157             self._failed_cache[key] = e

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _compile_core(self, args, return_type)
    165         flags = self._customize_flags(flags)
    166 
--> 167         impl = self._get_implementation(args, )
    168         cres = compiler.compile_extra(self.targetdescr.typing_context,
    169                                       self.targetdescr.target_context,

~\miniconda3\envs\pyquac\lib\site-packages\numba\core\dispatcher.py in _get_implementation(self, args, kws)
    201 
    202     def _get_implementation(self, args, kws):
--> 203         impl = self.py_func(*args, **kws)
    204         # Check the generating function and implementation signatures are
    205         # compatible, otherwise compiling would fail later.

~\AppData\Local\Temp/ipykernel_17704/1768352675.py in _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_, len_of_z)
     12 def _complicated(heat_list_, raw_array_x, raw_array_y, raw_array_z, x_min_, x_step_, y_min_, y_step_,
     13                  len_of_z):
---> 14     for i in range(len(raw_array_y)):
     15         heat_list_[np.round_((raw_array_y[i] - y_min_) / y_step_), 
     16                    np.round_((raw_array_x[i] - x_min_) / x_step_)] = raw_array_z[i]

TypeError: object of type 'Array' has no len()

【问题讨论】:

【参考方案1】:

与 Numpy 不同,Numba 不支持使用已经是 Numpy 数组的参数调用 np.array。这似乎是 Numpy 的一个错误或不受支持的功能(如果是这样,可以在 bug tracker 上报告)。在您的情况下,这并不重要,因为不需要此调用,因为输入已经是一个 Numpy 数组。如果情况并非总是如此,那么您可能需要使用 generated-jit 装饰器。但是,一个更简单的解决方案是确保输入参数是一个 Numpy 数组,否则,在调用 Numba 函数之前进行转换

【讨论】:

我尝试了你所说的一切,但现在它给出了一个相当奇怪的错误,指的是...range(len(z_val))。它说TypeError: 'Integer' object cannot be interpreted as an integer 还有TypeError: object of type 'Array' has no len() 在您调用该函数时,输入类型似乎很奇怪。我无法使用 float64[:,:,:] 类型的 raw_list_ 重现此特定问题。您能否提供准确的输入类型(即数组的维度、项目的类型)?除此之外,建议您直接在装饰器中为 Numba 指定类型,以便更早地捕获输入错误。

以上是关于尝试使用 numba 循环 numpy 数组时出错的主要内容,如果未能解决你的问题,请参考以下文章

强化学习技巧五:numba提速python程序

为啥在迭代 NumPy 数组时 Cython 比 Numba 慢得多?

各种 numpy 花式索引方法的性能,也与 numba

如何使用numba加快以下代码的速度?

对于纯 numpy 代码,使用 numba 的收益在哪里?

将数据从CSV转换为numpy数组时出错