我想使用 NTT 进行快速平方(请参阅快速 bignum 平方计算),但即使对于非常大的数字……超过 12000 位.
所以我的问题是:
- 有没有办法优化我的 NTT 转换?我并不是要通过并行(线程)来加速它;这只是低级层.
- 有没有办法加快我的模块化算术的速度?
这是我在 C++ 中为 NTT 编写的(已经优化的)源代码(它是完整的并且 100% 在 C++ 中工作,不需要第三方库,并且还应该是线程安全的.注意源数组被用作临时!!!,也不能将数组转化为自身).
//--------------------------------------------------------------------------- class fourier_NTT // Number theoretic transform { public: DWORD r,L,p,N; DWORD W,iW,rN; fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; } // main interface void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n]) void INTT(DWORD *dst,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n]) // Helper functions bool init(DWORD n); // init r,W,rN void NTT_fast(DWORD *dst,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n]) // Only for testing void NTT_slow(DWORD *dst,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n]) void INTT_slow(DWORD *dst,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n]) // DWORD arithmetics DWORD shl(DWORD a); DWORD shr(DWORD a); // Modular arithmetics DWORD mod(DWORD a); DWORD modadd(DWORD a,DWORD b); DWORD modsub(DWORD a,DWORD b); DWORD modmul(DWORD a,DWORD b); DWORD modpow(DWORD a,DWORD b); }; //--------------------------------------------------------------------------- void fourier_NTT:: NTT(DWORD *dst,DWORD n) { if (n>0) init(n); NTT_fast(dst,src,N,W); // NTT_slow(dst,W); } //--------------------------------------------------------------------------- void fourier_NTT::INTT(DWORD *dst,iW); for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN); // INTT_slow(dst,W); } //--------------------------------------------------------------------------- bool fourier_NTT::init(DWORD n) { // (max(src[])^2)*n < p else NTT overflow can ocur !!! r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit N=n; // size of vectors [DWORDs] W=modpow(r,L); // Wn for NTT iW=modpow(r,p-1-L); // Wn for INTT rN=modpow(n,p-2 ); // scale for INTT return true; } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,DWORD w) { if (n<=1) { if (n==1) dst[0]=src[0]; return; } DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w); // reorder even,odd for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j]; for ( j=1;i<n ;i++,j+=2) dst[i]=src[j]; // recursion NTT_fast(src,dst,n2,w2); // even NTT_fast(src+n2,dst+n2,w2); // odd // restore results for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w)) { a0=src[i]; a1=modmul(src[j],w2); dst[i]=modadd(a0,a1); dst[j]=modsub(a0,a1); } } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_slow(DWORD *dst,DWORD w) { DWORD i,wj,wi,a,n2=n>>1; for (wj=1,j=0;j<n;j++) { a=0; for (wi=1,i=0;i<n;i++) { a=modadd(a,modmul(wi,src[i])); wi=modmul(wi,wj); } dst[j]=a; wj=modmul(wj,w); } } //--------------------------------------------------------------------------- void fourier_NTT::INTT_slow(DWORD *dst,wi=1,wj=1,wj); } dst[j]=modmul(a,rN); wj=modmul(wj,iW); } } //--------------------------------------------------------------------------- DWORD fourier_NTT::shl(DWORD a) { return (a<<1)&0xFFFFFFFE; } DWORD fourier_NTT::shr(DWORD a) { return (a>>1)&0x7FFFFFFF; } //--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { DWORD bb; for (bb=p;(DWORD(a)>DWORD(bb))&&(!DWORD(bb&0x80000000));bb=shl(bb)); for (;;) { if (DWORD(a)>=DWORD(bb)) a-=bb; if (bb==p) break; bb =shr(bb); } return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,DWORD b) { DWORD d,cy; a=mod(a); b=mod(b); d=a+b; cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000; if (cy) d-=p; if (DWORD(d)>=DWORD(p)) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; a=mod(a); b=mod(b); d=a-b; if (DWORD(a)<DWORD(b)) d+=p; if (DWORD(d)>=DWORD(p)) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { // b bez orezania ! int i; DWORD d; a=mod(a); for (d=0,i=0;i<32;i++) { if (DWORD(a&1)) d=modadd(d,b); a=shr(a); b=modadd(b,b); } return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // a,b bez orezania ! int i; DWORD d=1; for (i=0;i<32;i++) { d=modmul(d,d); if (DWORD(b&0x80000000)) d=modmul(d,a); b=shl(b); } return d; } //---------------------------------------------------------------------------
我的 NTT 类的使用示例:
fourier_NTT ntt; const DWORD n=32 DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N]; ntt.NTT(z,x,N); // z[N]=NTT(x[N]),also init constants for N ntt.NTT(x,y); // x[N]=NTT(y[N]),no recompute of constants,use last N // modular convolution y[]=z[].x[] for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]); ntt.INTT(x,y); // x[N]=INTT(y[N]),use last N // x[]=convolution of original x[].y[]
优化前的一些测量(非 NTT 类):
a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul
我优化后的一些测量(当前代码、较低的递归参数大小/计数和更好的模块化算法):
a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.214 ms ] fast sqr sqr2[ 208.298 ms ] NTT sqr mul1[ 5.564 ms ] simpe mul mul2[ 3.113 ms ] karatsuba mul mul3[ 302.740 ms ] NTT mul
检查 NTT mul 和 NTT sqr 时间(我的优化加快了 3 倍多一点).它只有 1 倍循环,所以它不是很精确(误差 ~ 10%),但即使现在加速也很明显(通常我循环它 1000 倍甚至更多,但我的 NTT 太慢了).
您可以自由使用我的代码...只需将我的昵称和/或指向此页面的链接保留在某处(rem in code、readme.txt、about 或其他内容).我希望它有帮助......(我没有在任何地方看到快速 NTT 的 C++ 源代码,所以我不得不自己编写).对所有接受的 N 都测试了统一根,请参阅 fourier_NTT::init(DWORD n) 函数.
P.S.:有关 NTT 的更多信息,请参阅从复杂 FFT 到有限域 FFT 的转换.此代码基于我在该链接中的帖子.
[edit1:] 代码的进一步变化
通过利用模素数始终为 0xC0000001 并消除不必要的调用,我设法进一步优化了我的模算术.由此产生的加速现在令人惊叹(超过 40 倍),并且在大约 1500 * 32 位阈值之后,NTT 乘法比 karatsuba 更快.顺便说一句,我的 NTT 的速度现在与我在 64 位双打上优化的 DFFT 的速度相同.
一些测量:
a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] karatsuba mul mul3[ 26.311 ms ] NTT mul
模块化算术的新源代码:
//--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { if (a>p) a-=p; return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,cy; if (a>p) a-=p; if (b>p) b-=p; d=a+b; cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000; if (cy ) d-=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; if (a>p) a-=p; if (b>p) b-=p; d=a-b; if (a<b) d+=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { DWORD _a,_b,_p; _a=a; _b=b; _p=p; asm { mov eax,_a mov ebx,_b mul ebx // H(edx),L(eax) = eax * ebx mov ebx,_p div ebx // eax = H(edx),L(eax) / ebx mov _a,edx // edx = H(edx),L(eax) % ebx } return _a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // b bez orezania! int i; DWORD d=1; if (a>p) a-=p; for (i=0;i<32;i++) { d=modmul(d,a); b<<=1; } return d; } //---------------------------------------------------------------------------
如您所见,函数 shl 和 shr 不再使用.我认为 modpow 可以进一步优化,但它不是一个关键函数,因为它只被调用了很少的次数.最关键的函数是 modmul,它似乎处于最佳状态.
其他问题:
- 还有其他方法可以加速 NTT 吗?
- 我对模块化算法的优化安全吗?(结果似乎是一样的,但我可能会遗漏一些东西.)
[edit2] 新的优化
a = 0.99991970486 | 2000*32 bits looped 10x sqr1[ 13.908 ms ] fast sqr sqr2[ 13.649 ms ] NTT sqr mul1[ 19.726 ms ] simpe mul mul2[ 31.808 ms ] karatsuba mul mul3[ 19.373 ms ] NTT mul
我从你的所有评论中实现了所有可用的东西(感谢你的洞察力).
加速:
- 通过移除不必要的安全模组(Mandalf The Beige)+2.5%
- +34.9% 通过使用预先计算的 W,iW 功率(神秘)
- +35% 总计
实际完整源代码:
//--------------------------------------------------------------------------- //--- Number theoretic transforms: 2.03 ------------------------------------- //--------------------------------------------------------------------------- #ifndef _fourier_NTT_h #define _fourier_NTT_h //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- class fourier_NTT // Number theoretic transform { public: DWORD r,rN; // W=(r^L) mod p,iW=inverse W,rN = inverse N DWORD *WW,*iWW,NN; // Precomputed (W,iW)^(0,..,NN-1) powers // Internals fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; } ~fourier_NTT(){ _free(); } void _free(); // Free precomputed W,iW powers tables void _alloc(DWORD n); // Allocate and precompute W,iW powers tables // Main interface void NTT(DWORD *dst,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n]) void iNTT(DWORD *dst,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n]) // Helper functions bool init(DWORD n); // init r,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n]) void NTT_fast(DWORD *dst,DWORD *w2,DWORD i2); // Only for testing void NTT_slow(DWORD *dst,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n]) void iNTT_slow(DWORD *dst,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n]) // Modular arithmetics (optimized,but it works only for p >= 0x80000000!!!) DWORD mod(DWORD a); DWORD modadd(DWORD a,DWORD b); }; //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- void fourier_NTT::_free() { NN=0; if ( WW) delete[] WW; WW=NULL; if (iWW) delete[] iWW; iWW=NULL; } //--------------------------------------------------------------------------- void fourier_NTT::_alloc(DWORD n) { if (n<=NN) return; DWORD *tmp,i,w; tmp=new DWORD[n]; if ((NN)&&( WW)) for (i=0;i<NN;i++) tmp[i]= WW[i]; if ( WW) delete[] WW; WW=tmp; WW[0]=1; for (i=NN?NN:1,w= WW[i-1];i<n;i++){ w=modmul(w,W); WW[i]=w; } tmp=new DWORD[n]; if ((NN)&&(iWW)) for (i=0;i<NN;i++) tmp[i]=iWW[i]; if (iWW) delete[] iWW; iWW=tmp; iWW[0]=1; for (i=NN?NN:1,w=iWW[i-1];i<n;i++){ w=modmul(w,iW); iWW[i]=w; } NN=n; } //--------------------------------------------------------------------------- void fourier_NTT:: NTT(DWORD *dst,WW,1); // NTT_fast(dst,W); } //--------------------------------------------------------------------------- void fourier_NTT::iNTT(DWORD *dst,iWW,rN); // iNTT_slow(dst,W); } //--------------------------------------------------------------------------- bool fourier_NTT::init(DWORD n) { // (max(src[])^2)*n < p else NTT overflow can ocur!!! r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit N=n; // Size of vectors [DWORDs] W=modpow(r,L); // Wn for NTT iW=modpow(r,p-1-L); // Wn for INTT rN=modpow(n,p-2 ); // Scale for INTT _alloc(n>>1); // Precompute W,iW powers return true; } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,w); // Reorder even,j+=2) dst[i]=src[j]; // Recursion NTT_fast(src,w2); // Even NTT_fast(src+n2,w2); // Odd // Restore results for (w2=1,a1); } } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,DWORD i2) { if (n<=1) { if (n==1) dst[0]=src[0]; return; } DWORD i,n2=n>>1; // Reorder even,j+=2) dst[i]=src[j]; // Recursion i=i2<<1; NTT_fast(src,w2,i); // Even NTT_fast(src+n2,i); // Odd // Restore results for (i=0,w2+=i2) { a0=src[i]; a1=modmul(src[j],*w2); dst[i]=modadd(a0,a; for (wj=1,w); } } //--------------------------------------------------------------------------- void fourier_NTT::iNTT_slow(DWORD *dst,iW); } } //--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { if (a>p) a-=p; return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,cy; //if (a>p) a-=p; //if (b>p) b-=p; d=a+b; cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000; if (cy ) d-=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; //if (a>p) a-=p; //if (b>p) b-=p; d=a-b; if (a<b) d+=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { // b is not mod(p)! int i; DWORD d=1; //if (a>p) a-=p; for (i=0;i<32;i++) { d=modmul(d,a); b<<=1; } return d; } //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- #endif //--------------------------------------------------------------------------- //---------------------------------------------------------------------------
通过将 NTT_fast 分成两个函数,仍然有可能使用更少的堆垃圾.一个带有 WW[],另一个带有 iWW[],这导致递归调用中的一个参数减少.但我对它的期望并不高(仅限 32 位指针),而是有一个功能可以在未来更好地管理代码.许多函数现在处于休眠状态(用于测试),例如慢变体、mod 和较旧的快速函数(使用 w 参数代替 *w2,i2).
为避免大数据集溢出,将输入数字限制为 p/4 位. 其中 p 是每个 的位数NTT 元素,因此对于此 32 位版本,请使用最大 (32 位/4 -> 8 位) 输入值.
[edit3] 用于测试的简单字符串 bigint 乘法
//--------------------------------------------------------------------------- char* mul_NTT(const char *sx,const char *sy) { char *s; int i,k,n; // n = min power of 2 <= 2 max length(x,y) for (i=0;sx[i];i++); for (n=1;n<i;n<<=1); i--; for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--; DWORD *x,*y,*xx,*yy,a; x=new DWORD[n]; xx=new DWORD[n]; y=new DWORD[n]; yy=new DWORD[n]; // Zero padding for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0; for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0; //NTT fourier_NTT ntt; ntt.NTT(xx,n); ntt.NTT(yy,y); // Convolution for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]); //INTT ntt.iNTT(yy,xx); //suma a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0; delete[] x; delete[] xx; delete[] y; delete[] yy; return s; } //---------------------------------------------------------------------------
我使用AnsiString,所以我希望将它移植到char*,我没有做错.看起来它工作正常(与 AnsiString 版本相比).
- sx,sy 是十进制整数
- 返回分配的字符串(char*)=sx*sy
- sx,sy are decadic integer numbers
- Returns allocated string (char*)=sx*sy
每 32 位数据字只有 ~4 位,因此没有溢出的风险,但当然速度较慢.在我的 bignum 库中,我使用二进制表示,并为 NTT 每 32 位 WORD 使用 8 位 块.如果 N 很大,那么风险更大......
玩得开心