如何在 jitclass 中创建一个 numpy 数组列表

how to create a list of numpy arrays in jitclass

我想创建一个 jitclass,它将存储一些 numpy 数组。我不知道其中有多少。所以我想创建一个 numpy 数组列表。 我对 numba 类型感到困惑,但发现了一些奇怪的解决方案。 这运行正常。

import numba
from numba import types, typed, typeof
from numba.experimental import jitclass
import numpy as np


spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    def __init__ (self, haha=typed.List.empty_list(numba.int64[:])):
        self.test = haha
        self.test.append(np.asarray([0]))

    def dump(self):
        self.test.append(np.asarray([1]))
        print(self.test)

a = myLIST()
a.dump()

但是当我删除冗余变量时,它失败了。

spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    def __init__ (self):
        self.test = typed.List.empty_list(numba.int64[:])
        self.test.append(np.asarray([0]))

    def dump(self):
        self.test.append(np.asarray([1]))
        print(self.test)

a = myLIST()
a.dump()

为什么会这样?

似乎将数组类型声明为 nb.int64[:] 并不能提供足够的信息来创建 class,除非您创建 Numba 的实例(haha 的默认值)可以用来推断类型。

相反,您可以声明:

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
spec = [('test', nb.typeof(nb.typed.List.empty_list(int_vector)))]

或更短:

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
spec = [('test', nb.types.ListType(int_vector))]

或者,如果您可以使用类型注释:

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")

@nb.experimental.jitclass
class my_list:

    test: nb.types.ListType(int_vector)

    def __init__(self):
        self.test = nb.typed.List.empty_list(int_vector)
        self.test.append(np.array([0]))

    def dump(self):
        self.test.append(np.array([1]))
        print(self.test)

a = my_list()
a.dump()