Cython 生成的 C++ 代码中可能存在的错误

Posted

技术标签:

【中文标题】Cython 生成的 C++ 代码中可能存在的错误【英文标题】:Possible bug in Cython generated C++ code 【发布时间】:2016-06-20 21:06:13 【问题描述】:

我正在尝试为 __gnu_parallel::sort 创建一个 Cython 包装器,其方式与在此线程 Parallel in-place sort for numpy arrays 中所做的相同。

这是我对 wrapparallel.pyx 的简化代码:

import cython
cimport cython 

cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel":
    cdef void sort[T](T first, T last) nogil 

def parallel_sort(double[::1] a):
    sort(&a[0], &a[a.shape[0] - 1])

我使用以下代码生成 c++ 代码:

cython --cplus wrapparallel.pyx

编译和链接:

g++ -g -march=native -Ofast -fpic -c wrapparallel.cpp -o wrapparallel.o -fopenmp -I/usr/include/python2.7 -I/usr/include/x86_64-linux-gnu/python2.7
g++ -g -march=native -Ofast -shared -o wrapparallel.so wrapparallel.o -lpthread -ldl  -lutil -lm  -lpython2.7 -lgomp 

现在来测试一下:

In [1]: import numpy as np
        from wrapparallel import parallel_sort

        a = np.random.randn(10)
        parallel_sort(a)
        a

Out[1]: array([-1.23569683, -1.05639448, -0.76990205, -0.2512839 , -0.25022328,
                0.12711458,  0.81659571,  0.92205287,  2.15019125, -0.45902146])

正如原始线程中的注释中所指出的,此代码不会对最后一个元素进行排序,并且注释者建议在调用中删除“-1”以在 pyx 文件中进行排序。但是,此更改不会解决任何问题,因为 a[a.shape[0]] 会超出范围。

这让我怀疑 c++ 代码中可能存在问题。实际调用 __gnu_parallel::sort 的 sn-p 如下所示:

static PyObject *__pyx_pf_12wrapparallel_parallel_sort(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_memviewslice __pyx_v_a) 
  PyObject *__pyx_r = NULL;
  __Pyx_RefNannyDeclarations
  Py_ssize_t __pyx_t_1;
  int __pyx_t_2;
  Py_ssize_t __pyx_t_3;
  int __pyx_lineno = 0;
  const char *__pyx_filename = NULL;
  int __pyx_clineno = 0;
  __Pyx_RefNannySetupContext("parallel_sort", 0);

  __pyx_t_1 = 0;
  __pyx_t_2 = -1;
  if (__pyx_t_1 < 0) 
    __pyx_t_1 += __pyx_v_a.shape[0];
    if (unlikely(__pyx_t_1 < 0)) __pyx_t_2 = 0;
   else if (unlikely(__pyx_t_1 >= __pyx_v_a.shape[0])) __pyx_t_2 = 0;
  if (unlikely(__pyx_t_2 != -1)) 
    __Pyx_RaiseBufferIndexError(__pyx_t_2);
    __pyx_filename = __pyx_f[0]; __pyx_lineno = 31; __pyx_clineno = __LINE__; goto __pyx_L1_error;
  
  __pyx_t_3 = ((__pyx_v_a.shape[0]) - 1);
  __pyx_t_2 = -1;
  if (__pyx_t_3 < 0) 
    __pyx_t_3 += __pyx_v_a.shape[0];
    if (unlikely(__pyx_t_3 < 0)) __pyx_t_2 = 0;
   else if (unlikely(__pyx_t_3 >= __pyx_v_a.shape[0])) __pyx_t_2 = 0;
  if (unlikely(__pyx_t_2 != -1)) 
    __Pyx_RaiseBufferIndexError(__pyx_t_2);
    __pyx_filename = __pyx_f[0]; __pyx_lineno = 31; __pyx_clineno = __LINE__; goto __pyx_L1_error;
  
  __gnu_parallel::sort<double *>((&(*((double *) ( /* dim=0 */ ((char *) (((double *) __pyx_v_a.data) + __pyx_t_1)) )))), (&(*((double *) ( /* dim=0 */ ((char *) (((double *) __pyx_v_a.data) + __pyx_t_3)) )))));


  /* function exit code */
  __pyx_r = Py_None; __Pyx_INCREF(Py_None);
  goto __pyx_L0;
  __pyx_L1_error:;
  __Pyx_AddTraceback("wrapparallel.parallel_sort", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = NULL;
  __pyx_L0:;
  __PYX_XDEC_MEMVIEW(&__pyx_v_a, 1);
  __Pyx_XGIVEREF(__pyx_r);
  __Pyx_RefNannyFinishContext();
  return __pyx_r;

我的 c++ 知识不足以掌握这里发生的事情,所以我的问题是:调用 __gnu_parallel::sort 是否有问题,我该如何更改它以在 memoryview 中也包含最后一个元素?

编辑:

sort(&amp;a[0], &amp;a[a.shape[0] - 1]) 更改为sort(&amp;a[0], &amp;a[a.shape[0]]) 的答案是正确的。但是,除非指示 cython 编译器使用 boundscheck = False 指令,否则这将引发 IndexError: Out of bounds on buffer access (axis 0) 。为了完整起见,wrapparallel.pyx 文件应如下所示:

# cython: boundscheck = False
import cython
cimport cython 

cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel":
    cdef void sort[T](T first, T last) nogil 

def parallel_sort(double[::1] a):
    sort(&a[0], &a[a.shape[0]])

【问题讨论】:

我的 c++ 知识不足以掌握这里发生的事情——我认为 Stroustrup 本人无法掌握该代码中发生的事情。 【参考方案1】:

告诉您删除-1 的人是对的。排序函数需要类似于range 的参数(例如range(0, 3) &lt;-&gt; [0, 1, 2]

因此,您需要为排序算法提供第一个指针,该指针不在您希望排序的数组中。给定以下数据:

addr | 0x00 | 0x01 | 0x02 | 0x03 |
-----+------+------+------+------+
elem | 3.12 | 5.89 | 0.56 |    - |

你会打电话给sort(addr, &amp;addr[3])

您可以想象排序函数以如下方式遍历数组中的项目:

void func(double *start, double *end) 
    for (double *current = start; current != end; current += 1) 
        double value = *current;
        // do something
    

请注意,当current 指针等于end 时循环停止,end 指针将永远不会被取消引用(访问)。

当您编写&amp;a[a.shape[0]] 时,编译器足够聪明,可以判断出您只是在尝试进行指针运算,而实际上不会取消引用无效指针。

【讨论】:

你可能想做current++

以上是关于Cython 生成的 C++ 代码中可能存在的错误的主要内容,如果未能解决你的问题,请参考以下文章

分发 Cython 生成的 cpp 文件

Cython 和重载的 c++ 构造函数

使用 Cython 和 C++ 组织项目

如何在 IDE 中调试 Cython

在 Cython 中使用 C++ STL 映射

使用 cuda 的 cython 扩展