如何通过 Fortran 中的 BLAS 加速高阶张量收缩的重塑?

How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?

相关问题Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer)

如果我有张量收缩 A[a,b] * B[b,c,d] = C[a,c,d] 如果我使用BLAS,我想我需要DGEMM(假设真实值),那么我可以

  1. 首先将张量 B[b,c,d] 重塑为 D[b,e] 其中 e = c*d,
  2. DGEMM,A[a,b] * D[b,e] = E[a,e]
  3. E[a,e] 重塑为 C[a,c,d]

问题是,reshape 没那么快 :( 我在 Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer) 中看到了讨论 ,在上面link,作者遇到了一些错误信息,除了reshape itself.

所以请问有没有方便的解决方法

[我在维度的大小前面加上字母 n 以避免在下面混淆张量和张量的大小]

如评论中所述,无需重塑。 Dgemm没有张量的概念,只知道数组。它所关心的只是那些数组在内存中以正确的顺序排列。由于 Fortran 是列专业,如果您使用 3 维数组来表示问题中的 3 维张量 B,它将在内存中 完全 与用于表示的 2 维数组相同二维张量 D。就矩阵 mult 而言,您现在需要做的就是获取形成正确长度结果的点积。这使您得出结论,如果您告诉 dgemm B 的前导 dim 为 nb,第二个 dim 为 nc*nd,您将得到正确的结果。这导致我们

ian@eris:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

ian@eris:~/work/stack$ cat reshape.f90
Program reshape_for_blas

  Use, Intrinsic :: iso_fortran_env, Only :  wp => real64, li => int64

  Implicit None

  Real( wp ), Dimension( :, :    ), Allocatable :: a
  Real( wp ), Dimension( :, :, : ), Allocatable :: b
  Real( wp ), Dimension( :, :, : ), Allocatable :: c1, c2
  Real( wp ), Dimension( :, :    ), Allocatable :: d
  Real( wp ), Dimension( :, :    ), Allocatable :: e

  Integer :: na, nb, nc, nd, ne
  
  Integer( li ) :: start, finish, rate

  Write( *, * ) 'na, nb, nc, nd ?'
  Read( *, * ) na, nb, nc, nd
  ne = nc * nd
  Allocate( a ( 1:na, 1:nb ) ) 
  Allocate( b ( 1:nb, 1:nc, 1:nd ) ) 
  Allocate( c1( 1:na, 1:nc, 1:nd ) ) 
  Allocate( c2( 1:na, 1:nc, 1:nd ) ) 
  Allocate( d ( 1:nb, 1:ne ) ) 
  Allocate( e ( 1:na, 1:ne ) ) 

  ! Set up some data
  Call Random_number( a )
  Call Random_number( b )

  ! With reshapes
  Call System_clock( start, rate )
  d = Reshape( b, Shape( d ) )
  Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a, Size( a, Dim = 1 ), &
                                            d, Size( d, Dim = 1 ), &
                                    0.0_wp, e, Size( e, Dim = 1 ) )
  c1 = Reshape( e, Shape( c1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for reshaping method ', Real( finish - start, wp ) / rate
  
  ! Direct
  Call System_clock( start, rate )
  Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a , Size( a , Dim = 1 ), &
                                            b , Size( b , Dim = 1 ), &
                                            0.0_wp, c2, Size( c2, Dim = 1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for straight  method ', Real( finish - start, wp ) / rate

  Write( *, * ) 'Difference between result matrices ', Maxval( Abs( c1 - c2 ) )

End Program reshape_for_blas
ian@eris:~/work/stack$ cat in
40 50 60 70
ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra -fcheck=all reshape.f90  -lblas
ian@eris:~/work/stack$ ./a.out < in
 na, nb, nc, nd ?
 Time for reshaping method    1.0515256000000001E-002
 Time for straight  method    5.8608790000000003E-003
 Difference between result matrices    0.0000000000000000     
ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra  reshape.f90  -lblas
ian@eris:~/work/stack$ ./a.out < in
 na, nb, nc, nd ?
 Time for reshaping method    1.3585931000000001E-002
 Time for straight  method    1.6730429999999999E-003
 Difference between result matrices    0.0000000000000000     

也就是说,我认为值得注意的是,重塑的开销是 O(N^2),而矩阵乘法的时间是 O(N^3)。因此,对于大型矩阵,由于重塑而导致的开销百分比将趋于零。现在代码性能不是唯一的考虑因素,代码的可读性和可维护性也很重要。因此,如果您发现 reshape 方法更具可读性并且您使用的矩阵足够大以至于开销不是重要的,那么您可以很好地使用 reshape,因为在这种情况下代码可读性可能比性能更重要。你来电。