programming-examples/java/Data_Structures/FFT.java
2019-11-15 12:59:38 +01:00

112 lines
2.9 KiB
Java

import java.math.BigInteger;
import java.util.Random;
public class FFT {
// a.length == b.length == 2^x
public static void fft(double[] a, double[] b, boolean invert) {
int n = a.length;
int shift = 32 - Integer.numberOfTrailingZeros(n);
for (int i = 1; i < n; i++) {
int j = Integer.reverse(i << shift);
if (i < j) {
double temp = a[i];
a[i] = a[j];
a[j] = temp;
temp = b[i];
b[i] = b[j];
b[j] = temp;
}
}
for (int len = 2; len <= n; len <<= 1) {
int halfLen = len >> 1;
double angle = 2 * Math.PI / len * (invert ? -1 : 1);
double wLenA = Math.cos(angle);
double wLenB = Math.sin(angle);
for (int i = 0; i < n; i += len) {
double wA = 1;
double wB = 0;
for (int j = 0; j < halfLen; j++) {
double uA = a[i + j];
double uB = b[i + j];
double vA = a[i + j + halfLen] * wA - b[i + j + halfLen] * wB;
double vB = a[i + j + halfLen] * wB + b[i + j + halfLen] * wA;
a[i + j] = uA + vA;
b[i + j] = uB + vB;
a[i + j + halfLen] = uA - vA;
b[i + j + halfLen] = uB - vB;
double nextWA = wA * wLenA - wB * wLenB;
wB = wA * wLenB + wB * wLenA;
wA = nextWA;
}
}
}
if (invert) {
for (int i = 0; i < n; i++) {
a[i] /= n;
b[i] /= n;
}
}
}
public static long[] multiply(long[] a, long[] b) {
int resultSize = Integer.highestOneBit(Math.max(a.length, b.length) - 1) << 2;
resultSize = Math.max(resultSize, 2);
double[] aReal = new double[resultSize];
double[] aImaginary = new double[resultSize];
double[] bReal = new double[resultSize];
double[] bImaginary = new double[resultSize];
for (int i = 0; i < a.length; i++)
aReal[i] = a[i];
for (int i = 0; i < b.length; i++)
bReal[i] = b[i];
fft(aReal, aImaginary, false);
fft(bReal, bImaginary, false);
for (int i = 0; i < resultSize; i++) {
double real = aReal[i] * bReal[i] - aImaginary[i] * bImaginary[i];
aImaginary[i] = aImaginary[i] * bReal[i] + bImaginary[i] * aReal[i];
aReal[i] = real;
}
fft(aReal, aImaginary, true);
long[] result = new long[resultSize];
long carry = 0;
for (int i = 0; i < resultSize; i++) {
result[i] = Math.round(aReal[i]) + carry;
carry = result[i] / 10;
result[i] %= 10;
}
return result;
}
// random test
public static void main(String[] args) {
Random rnd = new Random(1);
for (int step = 0; step < 1000; step++) {
int n1 = rnd.nextInt(10) + 1;
String s1 = "";
long[] a = new long[n1];
for (int i = 0; i < n1; i++) {
int x = rnd.nextInt(10);
s1 = x + s1;
a[i] = x;
}
int n2 = rnd.nextInt(10) + 1;
String s2 = "";
long[] b = new long[n2];
for (int i = 0; i < n2; i++) {
int x = rnd.nextInt(10);
s2 = x + s2;
b[i] = x;
}
long[] res = multiply(a, b);
String s = "";
for (long v : res) {
s = v + s;
}
BigInteger mul = new BigInteger(s1).multiply(new BigInteger(s2));
if (!mul.equals(new BigInteger(s)))
throw new RuntimeException();
}
}
}