模块化算法和 NTT(有限域 DFT)优化

如何解决模块化算法和 NTT(有限域 DFT)优化?

开发过程中遇到模块化算法和 NTT(有限域 DFT)优化的问题如何解决?下面主要结合日常开发的经验,给出你关于模块化算法和 NTT(有限域 DFT)优化的解决方法建议,希望对你解决模块化算法和 NTT(有限域 DFT)优化有所启发或帮助;

问题描述

我想使用 NTT 进行快速平方(请参阅快速 bignum 平方计算),但即使对于非常大的数字……超过 12000 位.

所以我的问题是:

  1. 有没有办法优化我的 NTT 转换?我并不是要通过并行(线程)来加速它;这只是低级层.
  2. 有没有办法加快我的模块化算术的速度?

这是我在 C++ 中为 NTT 编写的(已经优化的)源代码(它是完整的并且 100% 在 C++ 中工作,不需要第三方库,并且还应该是线程安全的.注意源数组被用作临时!!!,也不能将数组转化为自身).

//---------------------------------------------------------------------------
class fourier_NTT                                    // Number theoretic transform
    {

public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }

    // main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast  NTT(DWORD src[n])
    void INTT(DWORD *dst,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                       // init r,W,rN
    void  NTT_fast(DWORD *dst,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void INTT_slow(DWORD *dst,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // DWORD arithmetics
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);

    // Modular arithmetics
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };

//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,DWORD n)
    {
    if (n>0) init(n);
    NTT_fast(dst,src,N,W);
//    NTT_slow(dst,W);
    }

//---------------------------------------------------------------------------
void fourier_NTT::INTT(DWORD *dst,iW);
    for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
       //    INTT_slow(dst,W);
    }

//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
    {
    // (max(src[])^2)*n < p else NTT overflow can ocur !!!
    r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
//    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
     N=n;                // size of vectors [DWORDs]
     W=modpow(r,L);    // Wn for NTT
    iW=modpow(r,p-1-L);    // Wn for INTT
    rN=modpow(n,p-2  );    // scale for INTT
    return true;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD w)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
    // reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // recursion
    NTT_fast(src,dst,n2,w2);    // even
    NTT_fast(src+n2,dst+n2,w2);    // odd
    // restore results
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
        {
        a0=src[i];
        a1=modmul(src[j],w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD w)
    {
    DWORD i,wj,wi,a,n2=n>>1;
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
            }
        dst[j]=a;
        wj=modmul(wj,w);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT::INTT_slow(DWORD *dst,wi=1,wj=1,wj);
            }
        dst[j]=modmul(a,rN);
        wj=modmul(wj,iW);
        }
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::shl(DWORD a) { return (a<<1)&0xFFFFFFFE; }
DWORD fourier_NTT::shr(DWORD a) { return (a>>1)&0x7FFFFFFF; }

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    DWORD bb;
    for (bb=p;(DWORD(a)>DWORD(bb))&&(!DWORD(bb&0x80000000));bb=shl(bb));
    for (;;)
        {
        if (DWORD(a)>=DWORD(bb)) a-=bb;
        if (bb==p) break;
        bb =shr(bb);
        }
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    a=mod(a);
    b=mod(b);
    d=a+b;
    cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000;
    if (cy) d-=p;
    if (DWORD(d)>=DWORD(p)) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    a=mod(a);
    b=mod(b);
    d=a-b; if (DWORD(a)<DWORD(b)) d+=p;
    if (DWORD(d)>=DWORD(p)) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {    // b bez orezania !
    int i;
    DWORD d;
    a=mod(a);
    for (d=0,i=0;i<32;i++)
        {
        if (DWORD(a&1))    d=modadd(d,b);
        a=shr(a);
        b=modadd(b,b);
        }
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {    // a,b bez orezania !
    int i;
    DWORD d=1;
    for (i=0;i<32;i++)
        {
        d=modmul(d,d);
        if (DWORD(b&0x80000000)) d=modmul(d,a);
        b=shl(b);
        }
    return d;
    }
//---------------------------------------------------------------------------

我的 NTT 类的使用示例:

fourier_NTT ntt;
const DWORD n=32
DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N];

ntt.NTT(z,x,N);    // z[N]=NTT(x[N]),also init constants for N
ntt.NTT(x,y);    // x[N]=NTT(y[N]),no recompute of constants,use last N
// modular convolution y[]=z[].x[]
for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]);
ntt.INTT(x,y);    // x[N]=INTT(y[N]),use last N
// x[]=convolution of original x[].y[]

优化前的一些测量(非 NTT 类):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

我优化后的一些测量(当前代码、较低的递归参数大小/计数和更好的模块化算法):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.214 ms ] fast sqr
sqr2[ 208.298 ms ] NTT sqr
mul1[ 5.564 ms ] simpe mul
mul2[ 3.113 ms ] karatsuba mul
mul3[ 302.740 ms ] NTT mul

检查 NTT mul 和 NTT sqr 时间(我的优化加快了 3 倍多一点).它只有 1 倍循环,所以它不是很精确(误差 ~ 10%),但即使现在加速也很明显(通常我循环它 1000 倍甚至更多,但我的 NTT 太慢了).

您可以自由使用我的代码...只需将我的昵称和/或指向此页面的链接保留在某处(rem in code、readme.txt、about 或其他内容).我希望它有帮助......(我没有在任何地方看到快速 NTT 的 C++ 源代码,所以我不得不自己编写).对所有接受的 N 都测试了统一根,请参阅 fourier_NTT::init(DWORD n) 函数.

P.S.:有关 NTT 的更多信息,请参阅从复杂 FFT 到有限域 FFT 的转换.此代码基于我在该链接中的帖子.

[edit1:] 代码的进一步变化

通过利用模素数始终为 0xC0000001 并消除不必要的调用,我设法进一步优化了我的模算术.由此产生的加速现在令人惊叹(超过 40 倍),并且在大约 1500 * 32 位阈值之后,NTT 乘法比 karatsuba 更快.顺便说一句,我的 NTT 的速度现在与我在 64 位双打上优化的 DFFT 的速度相同.

一些测量:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] karatsuba mul
mul3[ 26.311 ms ] NTT mul

模块化算术的新源代码:

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,cy;
    if (a>p) a-=p;
    if (b>p) b-=p;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
    if (cy ) d-=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    if (a>p) a-=p;
    if (b>p) b-=p;
    d=a-b;
    if (a<b) d+=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm    {
        mov    eax,_a
        mov    ebx,_b
        mul    ebx        // H(edx),L(eax) = eax * ebx
        mov    ebx,_p
        div    ebx        // eax = H(edx),L(eax) / ebx
        mov    _a,edx    // edx = H(edx),L(eax) % ebx
        }
    return _a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {    // b bez orezania!
    int i;
    DWORD d=1;
    if (a>p) a-=p;
    for (i=0;i<32;i++)
        {
        d=modmul(d,a);
        b<<=1;
        }
    return d;
    }

//---------------------------------------------------------------------------

如您所见,函数 shl 和 shr​​ 不再使用.我认为 modpow 可以进一步优化,但它不是一个关键函数,因为它只被调用了很少的次数.最关键的函数是 modmul,它似乎处于最佳状态.

其他问题:

  • 还有其他方法可以加速 NTT 吗?
  • 我对模块化算法的优化安全吗?(结果似乎是一样的,但我可能会遗漏一些东西.)

[edit2] 新的优化

a = 0.99991970486 | 2000*32 bits
looped 10x
sqr1[  13.908 ms ] fast sqr
sqr2[  13.649 ms ] NTT sqr
mul1[  19.726 ms ] simpe mul
mul2[  31.808 ms ] karatsuba mul
mul3[  19.373 ms ] NTT mul

我从你的所有评论中实现了所有可用的东西(感谢你的洞察力).

加速:

  • 通过移除不必要的安全模组(Mandalf The Beige)+2.5%
  • +34.9% 通过使用预先计算的 W,iW 功率(神秘)
  • +35% 总计

实际完整源代码:

//---------------------------------------------------------------------------
//--- Number theoretic transforms: 2.03 -------------------------------------
//---------------------------------------------------------------------------
#ifndef _fourier_NTT_h
#define _fourier_NTT_h
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
class fourier_NTT        // Number theoretic transform
    {
public:
    DWORD r,rN;        // W=(r^L) mod p,iW=inverse W,rN = inverse N
    DWORD *WW,*iWW,NN;    // Precomputed (W,iW)^(0,..,NN-1) powers

    // Internals
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
    ~fourier_NTT(){ _free(); }
    void _free();                                            // Free precomputed W,iW powers tables
    void _alloc(DWORD n);                                    // Allocate and precompute W,iW powers tables

    // Main interface
    void  NTT(DWORD *dst,DWORD n=0);                // DWORD dst[n] = fast  NTT(DWORD src[n])
    void iNTT(DWORD *dst,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                          // init r,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])
    void  NTT_fast(DWORD *dst,DWORD *w2,DWORD i2);

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void iNTT_slow(DWORD *dst,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // Modular arithmetics (optimized,but it works only for p >= 0x80000000!!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    };
//---------------------------------------------------------------------------

//---------------------------------------------------------------------------
void fourier_NTT::_free()
    {
    NN=0;
    if ( WW) delete[]  WW;  WW=NULL;
    if (iWW) delete[] iWW; iWW=NULL;
    }

//---------------------------------------------------------------------------
void fourier_NTT::_alloc(DWORD n)
    {
    if (n<=NN) return;
    DWORD *tmp,i,w;
    tmp=new DWORD[n]; if ((NN)&&( WW)) for (i=0;i<NN;i++) tmp[i]= WW[i]; if ( WW) delete[]  WW;  WW=tmp;  WW[0]=1; for (i=NN?NN:1,w= WW[i-1];i<n;i++){ w=modmul(w,W);  WW[i]=w; }
    tmp=new DWORD[n]; if ((NN)&&(iWW)) for (i=0;i<NN;i++) tmp[i]=iWW[i]; if (iWW) delete[] iWW; iWW=tmp; iWW[0]=1; for (i=NN?NN:1,w=iWW[i-1];i<n;i++){ w=modmul(w,iW); iWW[i]=w; }
    NN=n;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,WW,1);
//    NTT_fast(dst,W);
    }

//---------------------------------------------------------------------------
void fourier_NTT::iNTT(DWORD *dst,iWW,rN);
//    iNTT_slow(dst,W);
    }

//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
    {
    // (max(src[])^2)*n < p else NTT overflow can ocur!!!
    r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
//    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
     N=n;                // Size of vectors [DWORDs]
     W=modpow(r,L);  // Wn for NTT
    iW=modpow(r,p-1-L);  // Wn for INTT
    rN=modpow(n,p-2  );  // Scale for INTT
    _alloc(n>>1);        // Precompute W,iW powers
    return true;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,w);

    // Reorder even,j+=2) dst[i]=src[j];

    // Recursion
    NTT_fast(src,w2);    // Even
    NTT_fast(src+n2,w2);    // Odd

    // Restore results
    for (w2=1,a1);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD i2)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,n2=n>>1;

    // Reorder even,j+=2) dst[i]=src[j];

    // Recursion
    i=i2<<1;
    NTT_fast(src,w2,i);    // Even
    NTT_fast(src+n2,i);    // Odd

    // Restore results
    for (i=0,w2+=i2)
        {
        a0=src[i];
        a1=modmul(src[j],*w2);
        dst[i]=modadd(a0,a;
    for (wj=1,w);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT::iNTT_slow(DWORD *dst,iW);
        }
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,cy;
    //if (a>p) a-=p;
    //if (b>p) b-=p;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
    if (cy ) d-=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    //if (a>p) a-=p;
    //if (b>p) b-=p;
    d=a-b;
    if (a<b) d+=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {    // b is not mod(p)!
    int i;
    DWORD d=1;
    //if (a>p) a-=p;
    for (i=0;i<32;i++)
        {
        d=modmul(d,a);
        b<<=1;
        }
    return d;
    }
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
#endif
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------

通过将 NTT_fast 分成两个函数,仍然有可能使用更少的堆垃圾.一个带有 WW[],另一个带有 iWW[],这导致递归调用中的一个参数减少.但我对它的期望并不高(仅限 32 位指针),而是有一个功能可以在未来更好地管理代码.许多函数现在处于休眠状态(用于测试),例如慢变体、mod 和较旧的快速函数(使用 w 参数代替 *w2,i2).

为避免大数据集溢出,将输入数字限制为 p/4 位. 其中 p 是每个 的位数NTT 元素,因此对于此 32 位版本,请使用最大 (32 位/4 -> 8 位) 输入值.

[edit3] 用于测试的简单字符串 bigint 乘法

//---------------------------------------------------------------------------
char* mul_NTT(const char *sx,const char *sy)
    {
    char *s;
    int i,k,n;
    // n = min power of 2 <= 2 max length(x,y)
    for (i=0;sx[i];i++); for (n=1;n<i;n<<=1);        i--;
    for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--;
    DWORD *x,*y,*xx,*yy,a;
    x=new DWORD[n]; xx=new DWORD[n];
    y=new DWORD[n]; yy=new DWORD[n];

    // Zero padding
    for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
    for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;

    //NTT
    fourier_NTT ntt;
    ntt.NTT(xx,n);
    ntt.NTT(yy,y);

    // Convolution
    for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);

    //INTT
    ntt.iNTT(yy,xx);

    //suma
    a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0;
    delete[] x; delete[] xx;
    delete[] y; delete[] yy;

    return s;
    }
//---------------------------------------------------------------------------

我使用AnsiString,所以我希望将它移植到char*,我没有做错.看起来它工作正常(与 AnsiString 版本相比).

  • sx,sy 是十进制整数
  • 返回分配的字符串(char*)=sx*sy
  • sx,sy are decadic integer numbers
  • Returns allocated string (char*)=sx*sy

每 32 位数据字只有 ~4 位,因此没有溢出的风险,但当然速度较慢.在我的 bignum 库中,我使用二进制表示,并为 NTT 每 32 位 WORD 使用 8 位 块.如果 N 很大,那么风险更大......

玩得开心

尚未找到解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

编程问答相关问答

是否可以将 Python 程序转换为 C/C++? 我需要实现几个算法,我不确定性能差距是否足以证明我在 C/C++ 中执行它时所经历的所有痛苦(我不擅长)).我想写一个简单的算法,并根据这样一个转换后的解决方案对其进行基准测
我想使用 NTT 进行快速平方(请参阅快速 bignum 平方计算),但即使对于非常大的数字……超过 12000 位.
以下代码: myQueue.enqueue(\'a\'); myQueue.enqueue(\'b\'); cout << myQueue.dequeue() << myQueue.dequeue();
据我所知,写时复制不是在 C++11 中实现符合标准的 std::string 的可行方法,但是当它最近在讨论中出现时,我发现我自己无法直接支持这种说法.
这篇文章的评论部分有一个关于使用 std::vector::reserve 的帖子() vs. std::vector::resize().
我了解 inline 本身是对编译器的建议,它可以自行决定是否内联函数,并且还会生成可链接的目标代码.
我最近遇到了一个问题 可以使用模数除法轻松解决,但输入是浮点数: 给定一个周期函数(例如sin)和一个只能在周期范围内计算它的计算机函数(例如[-π,π]),制作一个可以处理任何输入的函数.
我想了解某个函数在我的 C++ 程序中在 Linux 上执行所需的时间.之后,我想做一个速度比较.我看到了几个时间函数,但最终从 boost 得到了这个.时间:
微信公众号搜索 “ 程序精选 ” ,选择关注!
微信公众号搜 "程序精选"关注