Cython 定义函数类型

Cython defining type for functions

我正在尝试制作一个用 cython 构建的切片采样库。一个通用切片采样库,您可以在其中提供对数密度、起始值并获得结果。现在研究单变量模型。根据回复 here,我得出以下结论。

所以我在cSlice.pyx中定义了一个函数:

cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                        double increment_size = 0.5):
    some stuff
    return value

我在cSlice.pxd中定义了:

cdef ctypedef double (*f_type_1)(double)
cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                               double increment_size = *)

其中 logd 是通用单变量对数密度。

在我的分发文件中,假设 cDistribution.pyx,我有以下内容:

from cSlice cimport univariate_slice_sample, f_type_1

cdef double log_distribution(alpha_k, y_k, prior):
    some stuff
    return value

cdef double _sample_alpha_k_slice(
        double starter,
        double[:] y_k,
        Prior prior,
        double increment_size
        ):
    cdef f_type_1 f = lambda alpha_k: log_distribution(alpha_k), y_k, prior)
    return univariate_slice_sample(f, starter, increment_size)

cpdef double sample_alpha_k_slice(
        double starter,
        double[:] y_1,
        Prior prior,
        double increment_size = 0.5
        ):
    return _sample_alpha_1_slice(starter, y_1, prior, increment_size)

包装器,因为显然 cpdef 中不允许使用 lambda

当我尝试编译分发文件时,我得到以下信息:

cDistribution.pyx:289:22: Cannot convert Python object to 'f_type_1'

指向 cdef f_type_1 f = ... 行。

我不确定还能做什么。我希望此代码保持 C 速度,重要的是不要撞到 GIL。有什么想法吗?

您可以为任何 Python 函数 jit C-callback/wrapper(从 Python-object 转换为指针不能隐式完成),例如如何在 中解释.

但是,该函数的核心将保持缓慢的纯 Python 函数。 Numba 让您可以通过 @cfunc 创建真正的 C-callbacks。这是一个简化的例子:

from numba import cfunc 
@cfunc("float64(float64)")
def id_(x):
    return x

这是它的用法:

%%cython
ctypedef double(*f_type)(double)

cdef void c_print_double(double x, f_type f):
    print(2.0*f(x))

import numba
expected_signature = numba.float64(numba.float64)
def print_double(double x,f):
    # check the signature of f:
    if not f._sig == expected_signature:
        raise TypeError("cfunc has not the right type")
    # it is not possible to cast a Python object to a pointer directly,
    # so we cast the address first to unsigned long long
    c_print_double(x, <f_type><unsigned long long int>(f.address))

现在:

print_double(1.0, id_)
# 2.0

我们需要在 运行 时间内检查 cfunc-object 的签名,否则转换 <f_type><unsigned long long int>(f.address) 也会对签名错误的函数“起作用”——仅在通话期间(可能)崩溃或给有趣的调试错误带来困难。我只是不确定我的方法是否是最好的 - 即使它有效:

...
@cfunc("float32(float32)")
def id3_(x):
    return x

print_double(1.0, id3_)
# TypeError: cfunc has not the right type