Модульная арифметика и оптимизация NTT (конечного поля DFT)

Я хотел использовать NTT для быстрого возведения в квадрат (см. Быстрое вычисление квадратов bignum ), но результат медленный даже для действительно больших чисел .. более 12000 бит.

Поэтому мой вопрос:

  1. Есть ли способ оптимизировать преобразование NTT? Я не хотел ускорять его параллелизмом (нитями); это только слой низкого уровня.
  2. Есть ли способ ускорить мою модульную арифметику?

Это мой (уже оптимизированный) исходный код в C ++ для NTT (он полный и 100% работает на C ++ без какой-либо необходимости для сторонних библиотек, а также должен быть streamобезопасным. Остерегайтесь, что исходный массив используется как временный !!! , Также он не может преобразовать массив в себя).

//--------------------------------------------------------------------------- class fourier_NTT // Number theoretic transform { public: DWORD r,L,p,N; DWORD W,iW,rN; fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; } // main interface void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n]) void INTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n]) // Helper functions bool init(DWORD n); // init r,L,p,W,iW,rN void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n]) // Only for testing void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n]) void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n]) // DWORD arithmetics DWORD shl(DWORD a); DWORD shr(DWORD a); // Modular arithmetics DWORD mod(DWORD a); DWORD modadd(DWORD a,DWORD b); DWORD modsub(DWORD a,DWORD b); DWORD modmul(DWORD a,DWORD b); DWORD modpow(DWORD a,DWORD b); }; //--------------------------------------------------------------------------- void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n) { if (n>0) init(n); NTT_fast(dst,src,N,W); // NTT_slow(dst,src,N,W); } //--------------------------------------------------------------------------- void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n) { if (n>0) init(n); NTT_fast(dst,src,N,iW); for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN); // INTT_slow(dst,src,N,W); } //--------------------------------------------------------------------------- bool fourier_NTT::init(DWORD n) { // (max(src[])^2)*n < p else NTT overflow can ocur !!! r=2; p=0xC0000001; if ((n0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit // r=2; p=0x78000001; if ((n0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit // r=2; p=0x00010001; if ((n0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit // r=2; p=0x0a000001; if ((n0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit N=n; // size of vectors [DWORDs] W=modpow(r, L); // Wn for NTT iW=modpow(r,p-1-L); // Wn for INTT rN=modpow(n,p-2 ); // scale for INTT return true; } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w) { if (n>1,w2=modmul(w,w); // reorder even,odd for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j]; for ( j=1;i<n ;i++,j+=2) dst[i]=src[j]; // recursion NTT_fast(src ,dst ,n2,w2); // even NTT_fast(src+n2,dst+n2,n2,w2); // odd // restore results for (w2=1,i=0,j=n2;i>1; for (wj=1,j=0;j<n;j++) { a=0; for (wi=1,i=0;i>1; for (wj=1,j=0;j<n;j++) { a=0; for (wi=1,i=0;i<n;i++) { a=modadd(a,modmul(wi,src[i])); wi=modmul(wi,wj); } dst[j]=modmul(a,rN); wj=modmul(wj,iW); } } //--------------------------------------------------------------------------- DWORD fourier_NTT::shl(DWORD a) { return (a<>1)&0x7FFFFFFF; } //--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { DWORD bb; for (bb=p;(DWORD(a)>DWORD(bb))&&(!DWORD(bb&0x80000000));bb=shl(bb)); for (;;) { if (DWORD(a)>=DWORD(bb)) a-=bb; if (bb==p) break; bb =shr(bb); } return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,DWORD b) { DWORD d,cy; a=mod(a); b=mod(b); d=a+b; cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000; if (cy) d-=p; if (DWORD(d)>=DWORD(p)) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; a=mod(a); b=mod(b); d=ab; if (DWORD(a)=DWORD(p)) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { // b bez orezania ! int i; DWORD d; a=mod(a); for (d=0,i=0;i<32;i++) { if (DWORD(a&1)) d=modadd(d,b); a=shr(a); b=modadd(b,b); } return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // a,b bez orezania ! int i; DWORD d=1; for (i=0;i<32;i++) { d=modmul(d,d); if (DWORD(b&0x80000000)) d=modmul(d,a); b=shl(b); } return d; } //--------------------------------------------------------------------------- 

Пример использования моего NTT-classа:

 fourier_NTT ntt; const DWORD n=32 DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N]; ntt.NTT(z,x,N); // z[N]=NTT(x[N]), also init constants for N ntt.NTT(x,y); // x[N]=NTT(y[N]), no recompute of constants, use last N // modular convolution y[]=z[].x[] for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]); ntt.INTT(x,y); // x[N]=INTT(y[N]), no recompute of constants, use last N // x[]=convolution of original x[].y[] 

Некоторые измерения перед оптимизацией (без classа NTT):

 a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul 

Некоторые измерения после моих оптимизаций (текущий код, более низкий размер / количество параметров рекурсии и улучшенная модульная арифметика):

 a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.214 ms ] fast sqr sqr2[ 208.298 ms ] NTT sqr mul1[ 5.564 ms ] simpe mul mul2[ 3.113 ms ] karatsuba mul mul3[ 302.740 ms ] NTT mul 

Проверьте времена NTT mul и NTT (мои оптимизации ускоряют его чуть более 3 раз). Это всего лишь 1-кратный цикл, поэтому он не очень точный (ошибка ~ 10%), но ускорение заметно даже сейчас (обычно я зацикливаю его 1000 и более, но мой NTT слишком медленный для этого).

Вы можете свободно использовать мой код … Просто держите мой ник и / или ссылку на эту страницу где-нибудь (rem в коде, readme.txt, о чем угодно). Я надеюсь, что это поможет … (я не видел источник C ++ для быстрых NTT в любом месте, поэтому мне пришлось написать его самостоятельно). Корни единства были протестированы для всех принятых N, см. fourier_NTT::init(DWORD n) .

PS: Для получения дополнительной информации о NTT см. Https://stackoverflow.com/a/18547575/2521214 . Этот код основан на моих сообщениях внутри этой ссылки.

[edit1:] Дальнейшие изменения в коде

Я сумел еще больше оптимизировать свою модульную арифметику, используя это modulo prime allways 0xC0000001 и устраняя ненужные вызовы. В результате ускорение стало потрясающим (более 40 раз), и умножение NTT происходит быстрее, чем karatsuba после примерно 1500 * 32-битного порога. Кстати, скорость моего NTT теперь такая же, как и мой оптимизированный DFFT на 64-битных удвоениях.

Некоторые измерения:

 a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] karatsuba mul mul3[ 26.311 ms ] NTT mul 

Новый исходный код для модульной арифметики:

 //--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { if (a>p) a-=p; return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,DWORD b) { DWORD d,cy; if (a>p) a-=p; if (b>p) b-=p; d=a+b; cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000; if (cy ) d-=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; if (a>p) a-=p; if (b>p) b-=p; d=ab; if (ap) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { DWORD _a,_b,_p; _a=a; _b=b; _p=p; asm { mov eax,_a mov ebx,_b mul ebx // H(edx),L(eax) = eax * ebx mov ebx,_p div ebx // eax = H(edx),L(eax) / ebx mov _a,edx // edx = H(edx),L(eax) % ebx } return _a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // b bez orezania! int i; DWORD d=1; if (a>p) a-=p; for (i=0;i<32;i++) { d=modmul(d,d); if (DWORD(b&0x80000000)) d=modmul(d,a); b<<=1; } return d; } //--------------------------------------------------------------------------- 

Как вы можете видеть, функции shl и shr больше не используются. Я думаю, что modpow можно дополнительно оптимизировать, но это не критическая функция, потому что она называется только очень немногим. Самая важная функция – modmul, и это, кажется, в лучшей форме.

Дальнейшие вопросы:

  • Есть ли другой способ ускорить NTT?
  • Являются ли мои оптимизации модульной арифметики безопасными? (Результаты кажутся одинаковыми, но я мог что-то пропустить).

[edit2] Новые оптимизации

 a = 0.99991970486 | 2000*32 bits looped 10x sqr1[ 13.908 ms ] fast sqr sqr2[ 13.649 ms ] NTT sqr mul1[ 19.726 ms ] simpe mul mul2[ 31.808 ms ] karatsuba mul mul3[ 19.373 ms ] NTT mul 

Я использовал все полезные материалы из всех ваших комментариев (спасибо за понимание).

ускорения:

  • + 2,5%, удалив ненужные версии безопасности (Mandalf The Beige)
  • + 34,9% с использованием предварительно вычисляемых полномочий W, iW (Mystical)
  • + 35% всего

Фактический полный исходный код:

 //--------------------------------------------------------------------------- //--- Number theoretic transforms: 2.03 ------------------------------------- //--------------------------------------------------------------------------- #ifndef _fourier_NTT_h #define _fourier_NTT_h //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- class fourier_NTT // Number theoretic transform { public: DWORD r,L,p,N; DWORD W,iW,rN; // W=(r^L) mod p, iW=inverse W, rN = inverse N DWORD *WW,*iWW,NN; // Precomputed (W,iW)^(0,..,NN-1) powers // Internals fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; } ~fourier_NTT(){ _free(); } void _free(); // Free precomputed W,iW powers tables void _alloc(DWORD n); // Allocate and precompute W,iW powers tables // Main interface void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n]) void iNTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n]) // Helper functions bool init(DWORD n); // init r,L,p,W,iW,rN void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n]) void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2); // Only for testing void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n]) void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n]) // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!) DWORD mod(DWORD a); DWORD modadd(DWORD a,DWORD b); DWORD modsub(DWORD a,DWORD b); DWORD modmul(DWORD a,DWORD b); DWORD modpow(DWORD a,DWORD b); }; //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- void fourier_NTT::_free() { NN=0; if ( WW) delete[] WW; WW=NULL; if (iWW) delete[] iWW; iWW=NULL; } //--------------------------------------------------------------------------- void fourier_NTT::_alloc(DWORD n) { if (n<=NN) return; DWORD *tmp,i,w; tmp=new DWORD[n]; if ((NN)&&( WW)) for (i=0;i<NN;i++) tmp[i]= WW[i]; if ( WW) delete[] WW; WW=tmp; WW[0]=1; for (i=NN?NN:1,w= WW[i-1];i<n;i++){ w=modmul(w, W); WW[i]=w; } tmp=new DWORD[n]; if ((NN)&&(iWW)) for (i=0;i<NN;i++) tmp[i]=iWW[i]; if (iWW) delete[] iWW; iWW=tmp; iWW[0]=1; for (i=NN?NN:1,w=iWW[i-1];i0) init(n); NTT_fast(dst,src,N,WW,1); // NTT_fast(dst,src,N,W); // NTT_slow(dst,src,N,W); } //--------------------------------------------------------------------------- void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n) { if (n>0) init(n); NTT_fast(dst,src,N,iWW,1); // NTT_fast(dst,src,N,iW); for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN); // iNTT_slow(dst,src,N,W); } //--------------------------------------------------------------------------- bool fourier_NTT::init(DWORD n) { // (max(src[])^2)*n < p else NTT overflow can ocur!!! r=2; p=0xC0000001; if ((n0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit // r=2; p=0x78000001; if ((n0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit // r=2; p=0x00010001; if ((n0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit // r=2; p=0x0a000001; if ((n0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit N=n; // Size of vectors [DWORDs] W=modpow(r, L); // Wn for NTT iW=modpow(r,p-1-L); // Wn for INTT rN=modpow(n,p-2 ); // Scale for INTT _alloc(n>>1); // Precompute W,iW powers return true; } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w) { if (n>1,w2=modmul(w,w); // Reorder even,odd for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j]; for ( j=1;i<n ;i++,j+=2) dst[i]=src[j]; // Recursion NTT_fast(src ,dst ,n2,w2); // Even NTT_fast(src+n2,dst+n2,n2,w2); // Odd // Restore results for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w)) { a0=src[i]; a1=modmul(src[j],w2); dst[i]=modadd(a0,a1); dst[j]=modsub(a0,a1); } } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2) { if (n>1; // Reorder even,odd for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j]; for ( j=1;i<n ;i++,j+=2) dst[i]=src[j]; // Recursion i=i2<<1; NTT_fast(src ,dst ,n2,w2,i); // Even NTT_fast(src+n2,dst+n2,n2,w2,i); // Odd // Restore results for (i=0,j=n2;i<n2;i++,j++,w2+=i2) { a0=src[i]; a1=modmul(src[j],*w2); dst[i]=modadd(a0,a1); dst[j]=modsub(a0,a1); } } //--------------------------------------------------------------------------- void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w) { DWORD i,j,wj,wi,a; for (wj=1,j=0;j<n;j++) { a=0; for (wi=1,i=0;i<n;i++) { a=modadd(a,modmul(wi,src[i])); wi=modmul(wi,wj); } dst[j]=a; wj=modmul(wj,w); } } //--------------------------------------------------------------------------- void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w) { DWORD i,j,wi=1,wj=1,a; for (wj=1,j=0;j<n;j++) { a=0; for (wi=1,i=0;ip) a-=p; return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,DWORD b) { DWORD d,cy; //if (a>p) a-=p; //if (b>p) b-=p; d=a+b; cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000; if (cy ) d-=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; //if (a>p) a-=p; //if (b>p) b-=p; d=ab; if (ap) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { DWORD _a,_b,_p; _a=a; _b=b; _p=p; asm { mov eax,_a mov ebx,_b mul ebx // H(edx),L(eax) = eax * ebx mov ebx,_p div ebx // eax = H(edx),L(eax) / ebx mov _a,edx // edx = H(edx),L(eax) % ebx } return _a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // b is not mod(p)! int i; DWORD d=1; //if (a>p) a-=p; for (i=0;i<32;i++) { d=modmul(d,d); if (DWORD(b&0x80000000)) d=modmul(d,a); b<<=1; } return d; } //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- #endif //--------------------------------------------------------------------------- //--------------------------------------------------------------------------- 

По-прежнему существует возможность использовать меньшее количество кучи, разделяя NTT_fast на две функции. Один с WW[] а другой с iWW[] что приводит к тому, что один параметр меньше в рекурсивных вызовах. Но я не ожидаю многого от него (только 32-разрядный указатель) и, скорее, имеет одну функцию для лучшего управления кодом в будущем. В настоящее время многие функции неактивны (для тестирования). Подобно медленным вариантам, mod и старшей быстрой функции (с параметром w вместо *w2,i2 ).

Чтобы избежать переполнения больших наборов данных, ограничьте входные номера до p/4 бит. Где p – количество бит на элемент NTT, поэтому для этой 32-разрядной версии используйте максимальные (32 bit/4 -> 8 bit) входные значения.

[edit3] Простое умножение bigint для тестирования

 //--------------------------------------------------------------------------- char* mul_NTT(const char *sx,const char *sy) { char *s; int i,j,k,n; // n = min power of 2 <= 2 max length(x,y) for (i=0;sx[i];i++); for (n=1;n<i;n<<=1); i--; for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<=0;i--,k++) x[k]=sx[i]-'0'; for (;k=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0; //NTT fourier_NTT ntt; ntt.NTT(xx,x,n); ntt.NTT(yy,y); // Convolution for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]); //INTT ntt.iNTT(yy,xx); //suma a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[ni-1]=(a%10)+'0'; a/=10; } s[n]=0; delete[] x; delete[] xx; delete[] y; delete[] yy; return s; } //--------------------------------------------------------------------------- 

Я использую AnsiString , поэтому я переношу его на char* надеюсь, я не ошибся. Похоже, что он работает правильно (по сравнению с версией AnsiString ).

  • sx,sy – десятичные целые числа
  • Возвращает выделенную строку (char*)=sx*sy

Это всего лишь ~ 4 бит на 32-битное слово данных, поэтому нет риска переполнения, но это, конечно, медленнее. В моем bignum lib я использую двоичное представление и использую 8 bit куски для 32-битного WORD для NTT . Более того, это опасно, если N большой …

Получайте удовольствие от этого

Во-первых, большое спасибо за публикацию и свободу использования. Я действительно ценю это.

Я смог использовать некоторые трюки для устранения некоторых ветвлений, перестроил основной цикл и модифицировал сборку и смог получить ускорение 1.35x.

Кроме того, я добавил условие препроцессора для 64 бит, поскольку Visual Studio не разрешает встроенную сборку в режиме 64 бит (спасибо Microsoft, не стесняйтесь сами по себе).

Что-то странное произошло, когда я оптимизировал функцию modsub (). Я переписал его с помощью бит-хаков, как я сделал modadd (что было быстрее). Но по какой-то причине мудрая версия modsub была медленнее. Не знаю, почему. Мог бы просто быть моим компьютером.

 // // Mandalf The Beige // Based on: // Spektre // http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations // // This code may be freely used however you choose, so long as it is accompanied by this notice. // #ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR #define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR #include  #ifndef uint32 #define uint32 unsigned long int #endif #ifndef uint64 #define uint64 unsigned long long int #endif class fast_ntt // number theoretic transform { public: fast_ntt() { r = 0; L = 0; W = 0; iW = 0; rN = 0; } // main interface void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n]) void INTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast INTT(uint32 src[n]) // helper functions private: bool init(uint32 n); // init r,L,p,W,iW,rN void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n]) void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n]) void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w); // only for testing void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow NTT(uint32 src[n]) void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow INTT(uint32 src[n]) // uint32 arithmetics // modular arithmetics inline uint32 modadd(uint32 a, uint32 b); inline uint32 modsub(uint32 a, uint32 b); inline uint32 modmul(uint32 a, uint32 b); inline uint32 modpow(uint32 a, uint32 b); uint32 r, L, N;//, p; uint32 W, iW, rN; const uint32 p = 0xC0000001; }; //--------------------------------------------------------------------------- void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n) { if (n > 0) { init(n); } NTT_fast(dst, src, N, W); // NTT_slow(dst,src,N,W); } //--------------------------------------------------------------------------- void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n) { if (n > 0) { init(n); } NTT_fast(dst, src, N, iW); for (uint32 i = 0; i 0x10000000)) { r = 0; L = 0; W = 0; // p = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit N = n; // size of vectors [uint32s] W = modpow(r, L); // Wn for NTT iW = modpow(r, p - 1 - L); // Wn for INTT rN = modpow(n, p - 2); // scale for INTT return true; } //--------------------------------------------------------------------------- void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w) { if(n > 1) { if(dst != src) { NTT_calc(dst, src, n, w); } else { uint32* temp = new uint32[n]; NTT_calc(temp, src, n, w); memcpy(dst, temp, n * sizeof(uint32)); delete [] temp; } } else if(n == 1) { dst[0] = src[0]; } } void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w) { if (n > 1) { uint32* temp = new uint32[n]; memcpy(temp, src, n * sizeof(uint32)); NTT_calc(dst, temp, n, w); delete[] temp; } else if (n == 1) { dst[0] = src[0]; } } void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w) { if(n > 1) { uint32 i, j, a0, a1, n2 = n >> 1, w2 = modmul(w, w); // reorder even,odd for (i = 0, j = 0; i < n2; i++, j += 2) { dst[i] = src[j]; } for (j = 1; i < n; i++, j += 2) { dst[i] = src[j]; } // recursion if(n2 > 1) { NTT_calc(src, dst, n2, w2); // even NTT_calc(src + n2, dst + n2, n2, w2); // odd } else if(n2 == 1) { src[0] = dst[0]; src[1] = dst[1]; } // restore results w2 = 1, i = 0, j = n2; a0 = src[i]; a1 = src[j]; dst[i] = modadd(a0, a1); dst[j] = modsub(a0, a1); while (++i < n2) { w2 = modmul(w2, w); j++; a0 = src[i]; a1 = modmul(src[j], w2); dst[i] = modadd(a0, a1); dst[j] = modsub(a0, a1); } } } //--------------------------------------------------------------------------- void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w) { uint32 i, j, wj, wi, a, n2 = n >> 1; for (wj = 1, j = 0; j < n; j++) { a = 0; for (wi = 1, i = 0; i < n; i++) { a = modadd(a, modmul(wi, src[i])); wi = modmul(wi, wj); } dst[j] = a; wj = modmul(wj, w); } } //--------------------------------------------------------------------------- void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w) { uint32 i, j, wi = 1, wj = 1, a, n2 = n >> 1; for (wj = 1, j = 0; j < n; j++) { a = 0; for (wi = 1, i = 0; i < n; i++) { a = modadd(a, modmul(wi, src[i])); wi = modmul(wi, wj); } dst[j] = modmul(a, rN); wj = modmul(wj, iW); } } //--------------------------------------------------------------------------- uint32 fast_ntt::modadd(uint32 a, uint32 b) { uint32 d; d = a + b; if(d < a) { d -= p; } if (d >= p) { d -= p; } return d; } //--------------------------------------------------------------------------- uint32 fast_ntt::modsub(uint32 a, uint32 b) { uint32 d; d = a - b; if (d > a) { d += p; } return d; } //--------------------------------------------------------------------------- uint32 fast_ntt::modmul(uint32 a, uint32 b) { uint32 _a = a; uint32 _b = b; // Original uint32 _p = p; __asm { mov eax, _a; mul _b; div _p; mov eax, edx; }; } uint32 fast_ntt::modpow(uint32 a, uint32 b) { //* uint64 D, M, A, P; P = p; A = a; M = 0llu - (b & 1); D = (M & A) | ((~M) & 1); while ((b >>= 1) != 0) { A = modmul(A, A); //A = (A * A) % P; if ((b & 1) == 1) { //D = (D * A) % P; D = modmul(D, A); } } return (uint32)D; } 

Новый modmul

 uint32 fast_ntt::modmul(uint32 a, uint32 b) { uint32 _a = a; uint32 _b = b; __asm { mov eax, a; mul b; mov ebx, eax; mov eax, 2863311530; mov ecx, edx; mul edx; shld edx, eax, 1; mov eax, 3221225473; mul edx; sub ebx, eax; mov eax, 3221225473; sbb ecx, edx; jc addback; neg ecx; and ecx, eax; sub ebx, ecx; sub ebx, eax; sbb edx, edx; and eax, edx; addback: add eax, ebx; }; } 

[EDIT] Spektre, основываясь на ваших отзывах, я изменил modadd & modsub на их оригинал. Я также понял, что внес некоторые изменения в рекурсивную функцию NTT, которой я не должен был иметь.

[EDIT2] Удалены ненужные операторы if и побитовые функции.

[EDIT3] Добавлена ​​новая встроенная assembly modmul.

  • Что такое микробиблиотека?
  • Как я могу ускорить запрос MySQL с большим смещением в предложении LIMIT?
  • Замена 32-битного счетчика циклов на 64-битные значения приводит к сумасшедшим отклонениям производительности
  • Алгоритм вычисления числа делителей заданного числа
  • Выполнение стресс-теста в веб-приложении?
  • Android webview slow
  • Отладка и производительность релиза
  • postgresql COUNT (DISTINCT ...) очень медленно
  • Насколько быстрее C ++, чем C #?
  • Игнорирование параметра NULL в T-SQL
  • Является ли оператор неравенства быстрее, чем оператор равенства?
  • Давайте будем гением компьютера.