使用 SSE 以最快的速度实现自然指数函数
Fastest Implementation of the Natural Exponential Function Using SSE
我正在寻找对 SSE 元素运行的自然指数函数的近似值。即——__m128 exp( __m128 x )
。
我有一个实现速度很快但准确性似乎很低的实现:
static inline __m128 FastExpSse(__m128 x)
{
__m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
__m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
__m128 m87 = _mm_set1_ps(-87);
// fast exponential function, x should be in [-87, 87]
__m128 mask = _mm_cmpge_ps(x, m87);
__m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}
任何人都可以实现更准确但速度更快(或更快)的实现吗?
如果它是用 C 风格编写的,我会很高兴。
谢谢。
下面的 C 代码是我在 previous answer 类似问题中使用的算法的 SSE 内在函数的翻译。
基本思想是将标准指数函数的计算转化为2的幂的计算:expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504)
。我们将 t = x * 1.44269504
拆分为一个整数 i
和一个分数 f
,这样 t = i + f
和 0 <= f <= 1
。我们现在可以用多项式近似计算 2f,然后通过将 i
添加到单精度浮点结果。
SSE 实现存在的一个问题是我们想要计算 i = floorf (t)
,但是没有快速的方法来计算 floor()
函数。然而,我们观察到对于正数,floor(x) == trunc(x)
,对于负数,floor(x) == trunc(x) - 1
,除非 x
是负整数。但是,由于核心近似可以处理 1.0f
的 f
值,因此对负参数使用近似是无害的。 SSE提供了一个指令将单精度浮点操作数转换为带截断的整数,所以这个解决方案是高效的。
Peter Cordes指出SSE4.1支持fast floor函数_mm_floor_ps()
,所以下面也展示了一个使用SSE4.1的变体。当启用 SSE 4.1 代码生成时,并非所有工具链都会自动预定义宏 __SSE4_1__
,但 gcc 会。
Compiler Explorer (Godbolt) 显示 gcc 7.2 将以下代码编译为 sixteen instructions for plain SSE and twelve instructions for SSE 4.1。
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif
/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, e, p, r;
__m128i i, j;
__m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
__m128 c0 = _mm_set1_ps (0.3371894346f);
__m128 c1 = _mm_set1_ps (0.657636276f);
__m128 c2 = _mm_set1_ps (1.00172476f);
/* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
#ifdef __SSE4_1__
e = _mm_floor_ps (t); /* floor(t) */
i = _mm_cvtps_epi32 (e); /* (int)floor(t) */
#else /* __SSE4_1__*/
i = _mm_cvttps_epi32 (t); /* i = (int)t */
j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
i = _mm_sub_epi32 (i, j); /* (int)t - signbit(t) */
e = _mm_cvtepi32_ps (i); /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
f = _mm_sub_ps (t, e); /* f = t - floor(t) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
int main (void)
{
union {
float f[4];
unsigned int i[4];
} arg, res;
double relerr, maxrelerr = 0.0;
int i, j;
__m128 x, y;
float start[2] = {-0.0f, 0.0f};
float finish[2] = {-87.33654f, 88.72283f};
for (i = 0; i < 2; i++) {
arg.f[0] = start[i];
arg.i[1] = arg.i[0] + 1;
arg.i[2] = arg.i[0] + 2;
arg.i[3] = arg.i[0] + 3;
do {
memcpy (&x, &arg, sizeof(x));
y = fast_exp_sse (x);
memcpy (&res, &y, sizeof(y));
for (j = 0; j < 4; j++) {
double ref = exp ((double)arg.f[j]);
relerr = fabs ((res.f[j] - ref) / ref);
if (relerr > maxrelerr) {
printf ("arg=% 15.8e res=%15.8e ref=%15.8e err=%15.8e\n",
arg.f[j], res.f[j], ref, relerr);
maxrelerr = relerr;
}
}
arg.i[0] += 4;
arg.i[1] += 4;
arg.i[2] += 4;
arg.i[3] += 4;
} while (fabsf (arg.f[3]) < fabsf (finish[i]));
}
printf ("maximum relative errror = %15.8e\n", maxrelerr);
return EXIT_SUCCESS;
}
fast_sse_exp()
的另一种设计以舍入到最近的模式提取调整参数 x / log(2)
的整数部分,使用众所周知的添加 "magic" 转换的技术constant 1.5 * 223 强制舍入到正确的位位置,然后再次减去相同的数字。这要求在加法期间有效的 SSE 舍入模式是 "round to nearest or even",这是默认值。 wim在评论中指出,一些编译器可能会在使用激进优化时将转换常量cvt
的加减运算优化为冗余,干扰此代码序列的功能,因此建议检查生成的机器代码。计算 2f 的近似区间现在以零为中心,因为 -0.5 <= f <= 0.5
,需要不同的核心近似。
/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, p, r;
__m128i i, j;
const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
const __m128 cvt = _mm_set1_ps (12582912.0f); /* 1.5 * (1 << 23) */
const __m128 c0 = _mm_set1_ps (0.238428936f);
const __m128 c1 = _mm_set1_ps (0.703448006f);
const __m128 c2 = _mm_set1_ps (1.000443142f);
/* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
f = _mm_sub_ps (t, r); /* f = t - rint (t) */
i = _mm_cvtps_epi32 (t); /* i = (int)t */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
问题中代码的算法似乎取自 Nicol N. Schraudolph 的作品,它巧妙地利用了 IEEE-754 二进制浮点格式的半对数性质:
N. N. Schraudolph. "A fast, compact approximation of the exponential function." 神经计算,11(4),1999 年 5 月,第 853-862 页。
删除参数限制代码后,它减少到只有三个 SSE 指令。 "magical" 校正常数 486411
对于最小化整个输入域的最大相对误差来说不是最佳的。基于简单的二进制搜索,值 298765
似乎更好,将 FastExpSse()
的最大相对误差降低到 3.56e-2,而 fast_exp_sse()
的最大相对误差为 1.73e-3。
/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
__m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
__m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
__m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
return _mm_castsi128_ps (t);
}
Schraudolph的算法基本上是对[0,1]中的f
使用线性逼近2f~=1.0 + f
,精度有待提高通过添加二次项。 Schraudolph 方法的聪明之处在于计算 2i * 2f 而没有明确地将整数部分 i = floor(x * 1.44269504)
与分数分开。我看不出有什么办法可以将这个技巧扩展到二次近似,但是可以肯定地将 Schraudolph 的 floor()
计算与上面使用的二次近似结合起来:
/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 f, p, r;
__m128i t, j;
const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
const __m128 c0 = _mm_set1_ps (0.3371894346f);
const __m128 c1 = _mm_set1_ps (0.657636276f);
const __m128 c2 = _mm_set1_ps (1.00172476f);
t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
j = _mm_and_si128 (t, m); /* j = (int)(floor (x/log(2))) << 23 */
t = _mm_sub_epi32 (t, j);
f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
通过使用 FastExpSse(x/2)/FastExpSse(- x/2) 而不是 FastExpSse(x)。这里的技巧是将偏移参数(上面的 298765)设置为零,以便分子和分母中的分段线性近似值对齐,从而消除大量误差。将其合并为一个函数:
__m128 BetterFastExpSse (__m128 x)
{
const __m128 a = _mm_set1_ps ((1 << 22) / float(M_LN2)); // to get exp(x/2)
const __m128i b = _mm_set1_epi32 (127 * (1 << 23)); // NB: zero shift!
__m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
__m128i s = _mm_add_epi32 (b, r);
__m128i t = _mm_sub_epi32 (b, r);
return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}
(我不是硬件专家 - 这里的性能杀手有多糟糕?)
如果您需要 exp(x) 只是为了得到 y = tanh(x)(例如对于神经网络),请按如下方式使用具有零偏移的 FastExpSse:
a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);
获得相同类型的错误取消福利。 logistic 函数的工作原理类似,使用零偏移的 FastExpSse(x/2)/(FastExpSse(x/2) + FastExpSse(-x/2))。 (这只是为了说明原理——您显然不想在这里多次计算 FastExpSse,而是按照上面 BetterFastExpSse 的方式将其合并为一个函数。)
我确实从中开发了一系列高阶近似值,更准确但也更慢。未发布,但如果有人想试一试,我们很乐意合作。
最后,为了一些乐趣:使用倒档获得 FastLogSse。将其与 FastExpSse 链接在一起可以同时消除运算符和错误,并弹出一个极快的幂函数...
回顾我当时的笔记,我确实探索了在不使用除法的情况下提高准确性的方法。我使用了相同的 reinterpret-as-float 技巧,但对尾数应用了多项式校正,这基本上是用 16 位定点算法计算的(当时唯一的快速计算方法)。
立方四次版本给你 4 resp。准确度的 5 位有效数字。没有必要增加阶数,因为低精度算术的噪声随后开始淹没多项式逼近的误差。以下是纯 C 版本:
#include <stdint.h>
float fastExp3(register float x) // cubic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (8.34e-5):
reinterpreter.i +=
((((((((1277*m) >> 14) + 14825)*m) >> 14) - 79749)*m) >> 11) - 626;
return reinterpreter.f;
}
float fastExp4(register float x) // quartic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (1.21e-5):
reinterpreter.i += (((((((((((3537*m) >> 16)
+ 13668)*m) >> 18) + 15817)*m) >> 14) - 80470)*m) >> 11);
return reinterpreter.f;
}
四次方服从 (fastExp4(0f) == 1f),这对于定点迭代算法很重要。
SSE 中这些整数乘移加序列的效率如何?在浮点算术同样快的体系结构上,可以使用它来代替,从而减少算术噪声。这基本上会产生上面@njuffa 的答案的三次和四次扩展。
有一篇关于创建这些方程(tanh、cosh、artanh、sinh 等)的快速版本的论文:
http://ijeais.org/wp-content/uploads/2018/07/IJAER180702.pdf
"Creating a Compiler Optimized Inlineable Implementation of Intel Svml Simd Intrinsics"
他们第 9 页的 tanh 方程 6 与@NicSchraudolph 的回答非常相似
对于 softmax 的使用,我将流程设想为:
auto a = _mm_mul_ps(x, _mm_set1_ps(12102203.2f));
auto b = _mm_castsi128_ps(_mm_cvtps_epi32(a)); // so far as in other variants
// copy 9 MSB from 0x3f800000 over 'b' so that 1 <= c < 2
// - also 1 <= poly_eval(...) < 2
auto c = replace_exponent(b, _mm_set1_ps(1.0f));
auto d = poly_eval(c, kA, kB, kC); // 2nd degree polynomial
auto e = replace_exponent(d, b); // restore exponent : 2^i * 2^f
指数复制可以按位 select 使用适当的掩码完成(AVX-512 有 vpternlogd
,而我实际上使用的是 Arm Neon vbsl
)。
所有的输入值x
必须是负的并且限制在-17-f(N) <= x <= -f(N)之间,这样当按(1<<23)/缩放时log(2),N 个结果浮点值的最大总和不会达到无穷大,并且倒数不会变为非正规。对于 N=3,f(N) = 4。较大的 f(N) 将权衡输入精度。
多值系数由 polyfit([1 1.5 2],[1 sqrt(2) 2])
生成,kA=0.343146,kB=-0.029437,kC=0.68292,生成严格小于 2 的值并防止不连续。通过计算 x=[1+max_err 1.5-eps 2], y=[1 2^(.5-eps) 2-max_err].[ 处的多项式可以减少最大平均误差。 =17=]
对于严格的SSE/AVX,1.0f的指数替换可以用(x & 0x007fffff) | 0x3f800000)
来完成。通过确保 poly_eval(x) 的计算结果为一个范围,可以找到后一个指数替换的两个指令序列,该范围可以直接与 b & 0xff800000
.
进行逻辑运算
为了我的目的,我开发了以下函数,可以快速准确地计算单精度自然指数。该函数适用于整个浮点值范围。代码写在 Visual Studio (x86) 下。使用 AVX 而不是 SSE,但这应该不是问题。此函数的精度几乎与标准 expf 函数相同,但速度明显更快。使用的近似值基于函数 f(t)=t/(2^(t/2)-1)+t/2 的切比雪夫级数展开,其中 t 来自 [-1; 1].感谢 Peter Cordes 的好建议。
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[7] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f, // b0
2.0f, // 2
4.65661287E-10f // 2^-31
};
_asm
{
mov ecx,offset ct // ecx contains the address of constants tables
vmulss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
cdq // edx=-1, if x<0 or overflow, otherwise edx=0
vmovss xmm3,[ecx+8] // Initialize the sum with highest coefficient 16*b2
and edx,4 // edx=4, if x<0 or overflow, otherwise edx=0
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
lea eax,[eax+8*edx] // Add 32 to exponent, if x<0
vfmsub231ss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,126 // The exponent of 2^(k-1) or 2^(k+31) with bias 127
jle exp_low // Jump if x<<0 or overflow (|x| too large or x=NaN)
vfmadd132ss xmm0,xmm1,[ecx+4] // xmm0 = t/2 (corrected value)
cmp eax,254 // Check that the exponent is not too large
jg exp_inf // Jump to set Inf if overflow
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of the polynomial
shl eax,23 // The bits of the float value 2^(k-1) or 2^(k+31)
vfmadd213ss xmm3,xmm2,[ecx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^(k-1) или 2^(k+31)
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[ecx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = e^x with shifted exponent of -1 or 31
vmulss xmm0,xmm0,[ecx+edx+20] // xmm0 = e^x
ret // Return
exp_low: // Handling the case of x<<0 or overflow
vucomiss xmm0,[ecx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
exp_inf: // Entry point for processing large x
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf in case x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}
下面我post一个简化的算法。此处删除了对结果中非规范化数字的支持。
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[5] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f // b0
};
_asm
{
mov edx,offset ct // edx contains the address of constants tables
vmulss xmm1,xmm0,[edx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
vmovss xmm3,[edx+8] // Initialize the sum with highest coefficient 16*b2
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
cmp eax,127 // Check that the exponent is not too large
jg exp_break // Jump to set Inf if overflow
vfmsub231ss xmm1,xmm0,[edx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,127 // Receive the exponent of 2^k with the bias 127
jle exp_break // The result is 0, if x<<0
vfmadd132ss xmm0,xmm1,[edx+4] // xmm0 = t/2 (corrected value)
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of polynomial
shl eax,23 // eax contains the bits of 2^k
vfmadd213ss xmm3,xmm2,[edx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^k
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[edx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = 2^k*(t/(f(t)-t/2)+1) = e^x
ret // Return
exp_break: // Get 0 for x<0 or Inf for x>>0
vucomiss xmm0,[edx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf, if x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}
我正在寻找对 SSE 元素运行的自然指数函数的近似值。即——__m128 exp( __m128 x )
。
我有一个实现速度很快但准确性似乎很低的实现:
static inline __m128 FastExpSse(__m128 x)
{
__m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
__m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
__m128 m87 = _mm_set1_ps(-87);
// fast exponential function, x should be in [-87, 87]
__m128 mask = _mm_cmpge_ps(x, m87);
__m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}
任何人都可以实现更准确但速度更快(或更快)的实现吗?
如果它是用 C 风格编写的,我会很高兴。
谢谢。
下面的 C 代码是我在 previous answer 类似问题中使用的算法的 SSE 内在函数的翻译。
基本思想是将标准指数函数的计算转化为2的幂的计算:expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504)
。我们将 t = x * 1.44269504
拆分为一个整数 i
和一个分数 f
,这样 t = i + f
和 0 <= f <= 1
。我们现在可以用多项式近似计算 2f,然后通过将 i
添加到单精度浮点结果。
SSE 实现存在的一个问题是我们想要计算 i = floorf (t)
,但是没有快速的方法来计算 floor()
函数。然而,我们观察到对于正数,floor(x) == trunc(x)
,对于负数,floor(x) == trunc(x) - 1
,除非 x
是负整数。但是,由于核心近似可以处理 1.0f
的 f
值,因此对负参数使用近似是无害的。 SSE提供了一个指令将单精度浮点操作数转换为带截断的整数,所以这个解决方案是高效的。
Peter Cordes指出SSE4.1支持fast floor函数_mm_floor_ps()
,所以下面也展示了一个使用SSE4.1的变体。当启用 SSE 4.1 代码生成时,并非所有工具链都会自动预定义宏 __SSE4_1__
,但 gcc 会。
Compiler Explorer (Godbolt) 显示 gcc 7.2 将以下代码编译为 sixteen instructions for plain SSE and twelve instructions for SSE 4.1。
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif
/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, e, p, r;
__m128i i, j;
__m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
__m128 c0 = _mm_set1_ps (0.3371894346f);
__m128 c1 = _mm_set1_ps (0.657636276f);
__m128 c2 = _mm_set1_ps (1.00172476f);
/* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
#ifdef __SSE4_1__
e = _mm_floor_ps (t); /* floor(t) */
i = _mm_cvtps_epi32 (e); /* (int)floor(t) */
#else /* __SSE4_1__*/
i = _mm_cvttps_epi32 (t); /* i = (int)t */
j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
i = _mm_sub_epi32 (i, j); /* (int)t - signbit(t) */
e = _mm_cvtepi32_ps (i); /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
f = _mm_sub_ps (t, e); /* f = t - floor(t) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
int main (void)
{
union {
float f[4];
unsigned int i[4];
} arg, res;
double relerr, maxrelerr = 0.0;
int i, j;
__m128 x, y;
float start[2] = {-0.0f, 0.0f};
float finish[2] = {-87.33654f, 88.72283f};
for (i = 0; i < 2; i++) {
arg.f[0] = start[i];
arg.i[1] = arg.i[0] + 1;
arg.i[2] = arg.i[0] + 2;
arg.i[3] = arg.i[0] + 3;
do {
memcpy (&x, &arg, sizeof(x));
y = fast_exp_sse (x);
memcpy (&res, &y, sizeof(y));
for (j = 0; j < 4; j++) {
double ref = exp ((double)arg.f[j]);
relerr = fabs ((res.f[j] - ref) / ref);
if (relerr > maxrelerr) {
printf ("arg=% 15.8e res=%15.8e ref=%15.8e err=%15.8e\n",
arg.f[j], res.f[j], ref, relerr);
maxrelerr = relerr;
}
}
arg.i[0] += 4;
arg.i[1] += 4;
arg.i[2] += 4;
arg.i[3] += 4;
} while (fabsf (arg.f[3]) < fabsf (finish[i]));
}
printf ("maximum relative errror = %15.8e\n", maxrelerr);
return EXIT_SUCCESS;
}
fast_sse_exp()
的另一种设计以舍入到最近的模式提取调整参数 x / log(2)
的整数部分,使用众所周知的添加 "magic" 转换的技术constant 1.5 * 223 强制舍入到正确的位位置,然后再次减去相同的数字。这要求在加法期间有效的 SSE 舍入模式是 "round to nearest or even",这是默认值。 wim在评论中指出,一些编译器可能会在使用激进优化时将转换常量cvt
的加减运算优化为冗余,干扰此代码序列的功能,因此建议检查生成的机器代码。计算 2f 的近似区间现在以零为中心,因为 -0.5 <= f <= 0.5
,需要不同的核心近似。
/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, p, r;
__m128i i, j;
const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
const __m128 cvt = _mm_set1_ps (12582912.0f); /* 1.5 * (1 << 23) */
const __m128 c0 = _mm_set1_ps (0.238428936f);
const __m128 c1 = _mm_set1_ps (0.703448006f);
const __m128 c2 = _mm_set1_ps (1.000443142f);
/* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
f = _mm_sub_ps (t, r); /* f = t - rint (t) */
i = _mm_cvtps_epi32 (t); /* i = (int)t */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
问题中代码的算法似乎取自 Nicol N. Schraudolph 的作品,它巧妙地利用了 IEEE-754 二进制浮点格式的半对数性质:
N. N. Schraudolph. "A fast, compact approximation of the exponential function." 神经计算,11(4),1999 年 5 月,第 853-862 页。
删除参数限制代码后,它减少到只有三个 SSE 指令。 "magical" 校正常数 486411
对于最小化整个输入域的最大相对误差来说不是最佳的。基于简单的二进制搜索,值 298765
似乎更好,将 FastExpSse()
的最大相对误差降低到 3.56e-2,而 fast_exp_sse()
的最大相对误差为 1.73e-3。
/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
__m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
__m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
__m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
return _mm_castsi128_ps (t);
}
Schraudolph的算法基本上是对[0,1]中的f
使用线性逼近2f~=1.0 + f
,精度有待提高通过添加二次项。 Schraudolph 方法的聪明之处在于计算 2i * 2f 而没有明确地将整数部分 i = floor(x * 1.44269504)
与分数分开。我看不出有什么办法可以将这个技巧扩展到二次近似,但是可以肯定地将 Schraudolph 的 floor()
计算与上面使用的二次近似结合起来:
/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 f, p, r;
__m128i t, j;
const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
const __m128 c0 = _mm_set1_ps (0.3371894346f);
const __m128 c1 = _mm_set1_ps (0.657636276f);
const __m128 c2 = _mm_set1_ps (1.00172476f);
t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
j = _mm_and_si128 (t, m); /* j = (int)(floor (x/log(2))) << 23 */
t = _mm_sub_epi32 (t, j);
f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
通过使用 FastExpSse(x/2)/FastExpSse(- x/2) 而不是 FastExpSse(x)。这里的技巧是将偏移参数(上面的 298765)设置为零,以便分子和分母中的分段线性近似值对齐,从而消除大量误差。将其合并为一个函数:
__m128 BetterFastExpSse (__m128 x)
{
const __m128 a = _mm_set1_ps ((1 << 22) / float(M_LN2)); // to get exp(x/2)
const __m128i b = _mm_set1_epi32 (127 * (1 << 23)); // NB: zero shift!
__m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
__m128i s = _mm_add_epi32 (b, r);
__m128i t = _mm_sub_epi32 (b, r);
return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}
(我不是硬件专家 - 这里的性能杀手有多糟糕?)
如果您需要 exp(x) 只是为了得到 y = tanh(x)(例如对于神经网络),请按如下方式使用具有零偏移的 FastExpSse:
a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);
获得相同类型的错误取消福利。 logistic 函数的工作原理类似,使用零偏移的 FastExpSse(x/2)/(FastExpSse(x/2) + FastExpSse(-x/2))。 (这只是为了说明原理——您显然不想在这里多次计算 FastExpSse,而是按照上面 BetterFastExpSse 的方式将其合并为一个函数。)
我确实从中开发了一系列高阶近似值,更准确但也更慢。未发布,但如果有人想试一试,我们很乐意合作。
最后,为了一些乐趣:使用倒档获得 FastLogSse。将其与 FastExpSse 链接在一起可以同时消除运算符和错误,并弹出一个极快的幂函数...
回顾我当时的笔记,我确实探索了在不使用除法的情况下提高准确性的方法。我使用了相同的 reinterpret-as-float 技巧,但对尾数应用了多项式校正,这基本上是用 16 位定点算法计算的(当时唯一的快速计算方法)。
立方四次版本给你 4 resp。准确度的 5 位有效数字。没有必要增加阶数,因为低精度算术的噪声随后开始淹没多项式逼近的误差。以下是纯 C 版本:
#include <stdint.h>
float fastExp3(register float x) // cubic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (8.34e-5):
reinterpreter.i +=
((((((((1277*m) >> 14) + 14825)*m) >> 14) - 79749)*m) >> 11) - 626;
return reinterpreter.f;
}
float fastExp4(register float x) // quartic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (1.21e-5):
reinterpreter.i += (((((((((((3537*m) >> 16)
+ 13668)*m) >> 18) + 15817)*m) >> 14) - 80470)*m) >> 11);
return reinterpreter.f;
}
四次方服从 (fastExp4(0f) == 1f),这对于定点迭代算法很重要。
SSE 中这些整数乘移加序列的效率如何?在浮点算术同样快的体系结构上,可以使用它来代替,从而减少算术噪声。这基本上会产生上面@njuffa 的答案的三次和四次扩展。
有一篇关于创建这些方程(tanh、cosh、artanh、sinh 等)的快速版本的论文:
http://ijeais.org/wp-content/uploads/2018/07/IJAER180702.pdf "Creating a Compiler Optimized Inlineable Implementation of Intel Svml Simd Intrinsics"
他们第 9 页的 tanh 方程 6 与@NicSchraudolph 的回答非常相似
对于 softmax 的使用,我将流程设想为:
auto a = _mm_mul_ps(x, _mm_set1_ps(12102203.2f));
auto b = _mm_castsi128_ps(_mm_cvtps_epi32(a)); // so far as in other variants
// copy 9 MSB from 0x3f800000 over 'b' so that 1 <= c < 2
// - also 1 <= poly_eval(...) < 2
auto c = replace_exponent(b, _mm_set1_ps(1.0f));
auto d = poly_eval(c, kA, kB, kC); // 2nd degree polynomial
auto e = replace_exponent(d, b); // restore exponent : 2^i * 2^f
指数复制可以按位 select 使用适当的掩码完成(AVX-512 有 vpternlogd
,而我实际上使用的是 Arm Neon vbsl
)。
所有的输入值x
必须是负的并且限制在-17-f(N) <= x <= -f(N)之间,这样当按(1<<23)/缩放时log(2),N 个结果浮点值的最大总和不会达到无穷大,并且倒数不会变为非正规。对于 N=3,f(N) = 4。较大的 f(N) 将权衡输入精度。
多值系数由 polyfit([1 1.5 2],[1 sqrt(2) 2])
生成,kA=0.343146,kB=-0.029437,kC=0.68292,生成严格小于 2 的值并防止不连续。通过计算 x=[1+max_err 1.5-eps 2], y=[1 2^(.5-eps) 2-max_err].[ 处的多项式可以减少最大平均误差。 =17=]
对于严格的SSE/AVX,1.0f的指数替换可以用(x & 0x007fffff) | 0x3f800000)
来完成。通过确保 poly_eval(x) 的计算结果为一个范围,可以找到后一个指数替换的两个指令序列,该范围可以直接与 b & 0xff800000
.
为了我的目的,我开发了以下函数,可以快速准确地计算单精度自然指数。该函数适用于整个浮点值范围。代码写在 Visual Studio (x86) 下。使用 AVX 而不是 SSE,但这应该不是问题。此函数的精度几乎与标准 expf 函数相同,但速度明显更快。使用的近似值基于函数 f(t)=t/(2^(t/2)-1)+t/2 的切比雪夫级数展开,其中 t 来自 [-1; 1].感谢 Peter Cordes 的好建议。
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[7] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f, // b0
2.0f, // 2
4.65661287E-10f // 2^-31
};
_asm
{
mov ecx,offset ct // ecx contains the address of constants tables
vmulss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
cdq // edx=-1, if x<0 or overflow, otherwise edx=0
vmovss xmm3,[ecx+8] // Initialize the sum with highest coefficient 16*b2
and edx,4 // edx=4, if x<0 or overflow, otherwise edx=0
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
lea eax,[eax+8*edx] // Add 32 to exponent, if x<0
vfmsub231ss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,126 // The exponent of 2^(k-1) or 2^(k+31) with bias 127
jle exp_low // Jump if x<<0 or overflow (|x| too large or x=NaN)
vfmadd132ss xmm0,xmm1,[ecx+4] // xmm0 = t/2 (corrected value)
cmp eax,254 // Check that the exponent is not too large
jg exp_inf // Jump to set Inf if overflow
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of the polynomial
shl eax,23 // The bits of the float value 2^(k-1) or 2^(k+31)
vfmadd213ss xmm3,xmm2,[ecx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^(k-1) или 2^(k+31)
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[ecx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = e^x with shifted exponent of -1 or 31
vmulss xmm0,xmm0,[ecx+edx+20] // xmm0 = e^x
ret // Return
exp_low: // Handling the case of x<<0 or overflow
vucomiss xmm0,[ecx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
exp_inf: // Entry point for processing large x
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf in case x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}
下面我post一个简化的算法。此处删除了对结果中非规范化数字的支持。
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[5] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f // b0
};
_asm
{
mov edx,offset ct // edx contains the address of constants tables
vmulss xmm1,xmm0,[edx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
vmovss xmm3,[edx+8] // Initialize the sum with highest coefficient 16*b2
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
cmp eax,127 // Check that the exponent is not too large
jg exp_break // Jump to set Inf if overflow
vfmsub231ss xmm1,xmm0,[edx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,127 // Receive the exponent of 2^k with the bias 127
jle exp_break // The result is 0, if x<<0
vfmadd132ss xmm0,xmm1,[edx+4] // xmm0 = t/2 (corrected value)
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of polynomial
shl eax,23 // eax contains the bits of 2^k
vfmadd213ss xmm3,xmm2,[edx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^k
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[edx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = 2^k*(t/(f(t)-t/2)+1) = e^x
ret // Return
exp_break: // Get 0 for x<0 or Inf for x>>0
vucomiss xmm0,[edx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf, if x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}