// https://web.stanford.edu/~liszt90/acm/notebook.html#file16
// Fast Fourier Transform (recursive radix-2 Cooley-Tukey).
// Used in many applications, one of which is fast polynomial multiplication.

#include <cassert>
#include <cstdio>
#include <cmath>

// Minimal complex number a + b*i.
struct cpx
{
  // Default construction yields 0 + 0i, so default-constructed arrays
  // never hold indeterminate values.
  cpx() : a(0), b(0) {}
  cpx(double aa) : a(aa), b(0) {}
  cpx(double aa, double bb) : a(aa), b(bb) {}
  double a;  // real part
  double b;  // imaginary part
  // Squared modulus |z|^2 = a^2 + b^2.
  double modsq(void) const { return a * a + b * b; }
  // Complex conjugate a - b*i.
  cpx bar(void) const { return cpx(a, -b); }
};

cpx operator +(cpx a, cpx b) { return cpx(a.a + b.a, a.b + b.b); }

cpx operator -(cpx a, cpx b) { return cpx(a.a - b.a, a.b - b.b); }

cpx operator *(cpx a, cpx b)
{
  return cpx(a.a * b.a - a.b * b.b, a.a * b.b + a.b * b.a);
}

// a / b computed as a * conj(b) / |b|^2. Dividing by 0 + 0i yields inf/nan
// components (no check is made).
cpx operator /(cpx a, cpx b)
{
  cpx r = a * b.bar();
  double denom = b.modsq();
  return cpx(r.a / denom, r.b / denom);
}

// Unit complex number e^{i*theta} = cos(theta) + i*sin(theta).
cpx EXP(double theta) { return cpx(std::cos(theta), std::sin(theta)); }

const double two_pi = 4 * std::acos(0.0);  // acos(0) == pi/2

// FFT(in, out, step, size, dir)
//   in:   input array; logical element j is read from in[j * step].
//   out:  output array of `size` elements. Must not overlap `in`: the first
//         recursive call writes out[0..size/2) before the second one reads
//         the odd-indexed inputs.
//   step: {SET TO 1} (larger strides are used internally by the recursion).
//   size: length of the input/output {MUST BE A POWER OF 2}; other lengths
//         silently produce wrong results. size < 1 is a no-op.
//   dir:  either plus or minus one (sign of the exponent / direction).
// RESULT: out[k] = \sum_{j=0}^{size - 1} in[j] * exp(dir * 2pi * i * j * k / size)
// The inverse transform is FFT with the opposite dir, followed by dividing
// every output by size. Recursion depth is log2(size).
void FFT(const cpx *in, cpx *out, int step, int size, int dir)
{
  if (size < 1) return;
  if (size == 1) {
    out[0] = in[0];
    return;
  }
  const int half = size / 2;
  // Even-indexed inputs -> out[0..half), odd-indexed inputs -> out[half..size).
  FFT(in, out, step * 2, half, dir);
  FFT(in + step, out + half, step * 2, half, dir);
  // Butterfly: w^(i + half) == -w^i, so one twiddle factor serves both outputs.
  for (int i = 0; i < half; i++) {
    cpx even = out[i];
    cpx t = EXP(dir * two_pi * i / size) * out[i + half];
    out[i] = even + t;
    out[i + half] = even - t;
  }
}

// Usage (cyclic convolution):
//   f[0..N-1] and g[0..N-1] are numbers. We want the convolution h, defined by
//     h[n] = sum of f[k] * g[n-k]   (k = 0, ..., N-1).
//   Here the index is cyclic: f[-1] = f[N-1], f[-2] = f[N-2], etc.
//   Let F[0..N-1] be FFT(f), and similarly define G and H.
//   The convolution theorem says H[n] = F[n] * G[n] (element-wise product).
//   To compute h[] in O(N log N) time, do the following:
//     1. Compute F and G (pass dir = 1 as the argument).
//     2. Get H by element-wise multiplying F and G.
//     3. Get h by taking the inverse FFT of H (use dir = -1 as the argument)
//        and *dividing by N*. DO NOT FORGET THIS SCALING FACTOR.
//   For a non-cyclic product (e.g. polynomial multiplication), zero-pad f and g
//   to a power of two N >= len(f) + len(g) - 1 first.
Get h by taking the inverse FFT (use dir = -1 as the argument) // and *dividing by N*. DO NOT FORGET THIS SCALING FACTOR. int main(void) { printf("If rows come in identical pairs, then everything works.\n"); cpx a[8] = {0, 1, cpx(1,3), cpx(0,5), 1, 0, 2, 0}; cpx b[8] = {1, cpx(0,-2), cpx(0,1), 3, -1, -3, 1, -2}; cpx A[8]; cpx B[8]; FFT(a, A, 1, 8, 1); FFT(b, B, 1, 8, 1); for(int i = 0 ; i < 8 ; i++) { printf("%7.2lf%7.2lf", A[i].a, A[i].b); } printf("\n"); for(int i = 0 ; i < 8 ; i++) { cpx Ai(0,0); for(int j = 0 ; j < 8 ; j++) { Ai = Ai + a[j] * EXP(j * i * two_pi / 8); } printf("%7.2lf%7.2lf", Ai.a, Ai.b); } printf("\n"); cpx AB[8]; for(int i = 0 ; i < 8 ; i++) AB[i] = A[i] * B[i]; cpx aconvb[8]; FFT(AB, aconvb, 1, 8, -1); for(int i = 0 ; i < 8 ; i++) aconvb[i] = aconvb[i] / 8; for(int i = 0 ; i < 8 ; i++) { printf("%7.2lf%7.2lf", aconvb[i].a, aconvb[i].b); } printf("\n"); for(int i = 0 ; i < 8 ; i++) { cpx aconvbi(0,0); for(int j = 0 ; j < 8 ; j++) { aconvbi = aconvbi + a[j] * b[(8 + i - j) % 8]; } printf("%7.2lf%7.2lf", aconvbi.a, aconvbi.b); } printf("\n"); return 0; }