Cython lapack 不会覆盖参数

Cython lapack does not overwrite parameters

我正在学习Cython lapack接口。我尝试使用 cython_lapack 包将一个在 C 中使用 lapack 的小示例转换为一个在 Cython 中使用的小示例。

C代码如下:

#include <stdio.h>
#include <lapacke.h>

int main()
{
  int i, j;
  float norm;
  float A[] = {1.44, -7.84, -4.39, 4.53, 
               -9.96, -0.28, -3.24,  3.83,
               -7.55, 3.24, 6.27, -6.64,
               8.34, 8.09, 5.28, 2.06,
               7.08, 2.52, 0.74, -2.47,
               -5.45, -5.70, -1.19, 4.70};
  float B[] = {8.58, 9.35,
               8.26, -4.43,
               8.48, -0.70,
               -5.28, -0.26,
               5.72, -7.36,
               8.93, -2.52};
  lapack_int m = 6, n = 4, nrhs = 2;
  int info;

  info = LAPACKE_sgels(LAPACK_ROW_MAJOR, 'N', m, n, nrhs, A, n, B, nrhs);
  printf("Least squares solution:\n");
  for (i = 0; i < n; ++i) {
    for (j = 0; j < nrhs; ++j) printf(" %6.2f", B[nrhs * i + j]);
    printf("\n");
  }
  printf("Residual sum of squares for the solution:\n");
  for (i = 0; i < nrhs; ++i) {
    norm = 0.0;
    for (j = n; j < m; ++j) norm += B[nrhs * j + i] * B[nrhs * j + i];
  printf(" %6.2f", norm);
 }
 printf("\n");
 return 0;
}

我翻译成Cython,如下图:

import numpy as np
cimport numpy as cnp
from scipy.linalg.cython_lapack cimport dgels

a = np.array([1.44, -9.96, -7.55, 8.34, 7.08, -5.45,
              7.84, -0.28, 3.24, 8.09, 2.52, -5.70,
              -4.39, -3.24, 6.27, 5.28, 0.74, -1.19,
              4.53, 3.83, -6.64, 2.06, -2.47, 4.70], dtype=np.float64)
b = np.array([8.58, 8.26, 8.48, -5.28, 5.72, 8.93,
              9.35, -4.43, -.70, -0.26, -7.36, -2.52], dtype=np.float64)

cdef int m = 6, n = 4, nrhs = 2, lda = 6, ldb = 6, lwork = -1, info = 1
cdef double work = 0.0

cdef double* A = <double*>(cnp.PyArray_DATA(a))
cdef double* B = <double*>(cnp.PyArray_DATA(b))

dgels('N', &m, &n, &nrhs, A, &lda, B, &ldb, &work, &lwork, &info)
print "A:"
for j in range(4):
    for i in range(6):
        print "%6.2f " % A[6*j+i],
    print ""
print "B:"
for j in range(2):
    for i in range(6):
        print "%6.2f " % B[6*j+i],
    print ""

我原以为 A 和 B 会被修改,就像 C 代码中的行为一样。然而,奇怪的是,它们根本没有变化。我想知道我应该怎么做才能得到正确的结果。

来自the Netlib documentation

If LWORK = -1, then a workspace query is assumed; the routine only calculates the optimal size of the WORK array, returns this value as the first entry of the WORK array, and no error message related to LWORK is issued by XERBLA.

您没有要求进行计算 - 您只是询问了应该使工作空间数组有多大。你可以

  1. 使用适当大小的工作区数组进行第二次调用,或者
  2. 只需根据文档分配一个工作空间数组(最好是 max( 1, MN + max( MN, NRHS )*NB ))并将其大小作为 LWORK 传递。