programming-examples/java/Data_Structures/Bwt.java

154 lines
4.2 KiB
Java
Raw Normal View History

2019-11-15 12:59:38 +01:00
import java.io.*;
import java.util.*;
// https://en.wikipedia.org/wiki/BurrowsWheeler_transform
public class Bwt {
public static TransformedData bwt(byte[] s) {
int n = s.length;
int[] sa = suffixArray(s);
byte[] last = new byte[n];
int position = 0;
for (int i = 0; i < n; i++) {
last[i] = s[(sa[i] + n - 1) % n];
if (sa[i] == 0)
position = i;
}
return new TransformedData(last, position);
}
public static class TransformedData {
final byte[] last;
final int position;
TransformedData(byte[] last, int position) {
this.last = last;
this.position = position;
}
}
// sort rotations of S in O(n*log(n))
static int[] suffixArray(byte[] S) {
int n = S.length;
Integer[] order = new Integer[n];
for (int i = 0; i < n; i++)
order[i] = i;
Arrays.sort(order, (a, b) -> Integer.compare(Byte.toUnsignedInt(S[a]), Byte.toUnsignedInt(S[b])));
int[] sa = new int[n];
int[] classes = new int[n];
for (int i = 0; i < n; i++) {
sa[i] = order[i];
classes[i] = S[i];
}
for (int len = 1; len < n; len *= 2) {
int[] c = classes.clone();
for (int i = 0; i < n; i++)
classes[sa[i]] = i > 0 && c[sa[i - 1]] == c[sa[i]] && c[(sa[i - 1] + len / 2) % n] == c[(sa[i] + len / 2) % n] ? classes[sa[i - 1]] : i;
int[] cnt = new int[n];
for (int i = 0; i < n; i++)
cnt[i] = i;
int[] s = sa.clone();
for (int i = 0; i < n; i++) {
int s1 = (s[i] - len + n) % n;
sa[cnt[classes[s1]]++] = s1;
}
}
return sa;
}
public static byte[] inverseBwt(byte[] last, int position) {
int[] cnt = new int[256];
for (byte b : last)
++cnt[Byte.toUnsignedInt(b)];
for (int i = 1; i < cnt.length; i++)
cnt[i] += cnt[i - 1];
int n = last.length;
int[] t = new int[n];
for (int i = n - 1; i >= 0; i--)
t[--cnt[Byte.toUnsignedInt(last[i])]] = i;
byte[] res = new byte[n];
int j = t[position];
for (int i = 0; i < n; i++) {
res[i] = last[j];
j = t[j];
}
return res;
}
static byte[] mtfEncode(byte[] s) {
byte[] table = new byte[256];
for (int i = 0; i < table.length; i++)
table[i] = (byte) i;
byte[] res = new byte[s.length];
for (int i = 0; i < s.length; i++) {
int pos;
for (pos = 0; table[pos] != s[i]; pos++) ;
res[i] = (byte) pos;
byte t = table[pos];
System.arraycopy(table, 0, table, 1, pos);
table[0] = t;
}
return res;
}
static byte[] mtfDecode(byte[] s) {
byte[] table = new byte[256];
for (int i = 0; i < table.length; i++)
table[i] = (byte) i;
byte[] res = new byte[s.length];
for (int i = 0; i < s.length; i++) {
int pos = Byte.toUnsignedInt(s[i]);
res[i] = table[pos];
byte t = table[pos];
System.arraycopy(table, 0, table, 1, pos);
table[0] = t;
}
return res;
}
// Usage example
public static void main(String[] args) throws Exception {
Random rnd = new Random(1);
for (int step = 0; step < 1000; step++) {
int n = rnd.nextInt(20) + 1;
byte[] s = new byte[n];
for (int i = 0; i < n; i++)
s[i] = (byte) (rnd.nextInt(256));
TransformedData transformedData = bwt(s);
byte[] s2 = inverseBwt(transformedData.last, transformedData.position);
if (!Arrays.equals(s, s2))
throw new RuntimeException();
}
FileInputStream fs = new FileInputStream("src/Bwt.java");
byte[] data = new byte[1_000_000];
int len = fs.read(data, 0, data.length);
data = Arrays.copyOf(data, len);
TransformedData bwt = bwt(data);
byte[] encoded = mtfEncode(bwt.last);
byte[] last2 = mtfDecode(encoded);
int[] encodedInts = new int[encoded.length];
for (int i = 0; i < encoded.length; i++)
encodedInts[i] = Byte.toUnsignedInt(encoded[i]);
byte[] data2 = inverseBwt(bwt.last, bwt.position);
System.out.println(Arrays.equals(data, data2));
System.out.println(Arrays.equals(bwt.last, last2));
Locale.setDefault(Locale.US);
System.out.printf("%d -> %.0f\n", len, optimalCompressedLength(encodedInts));
}
static double optimalCompressedLength(int[] a) {
int max = 0;
for (int x : a)
max = Math.max(max, x);
int[] freq = new int[max + 1];
for (int x : a)
++freq[x];
double optimalLength = 0;
for (int f : freq)
if (f > 0)
optimalLength += f * Math.log((double) a.length / f) / Math.log(2) / 8;
return optimalLength;
}
}