programming-examples/java/Data_Structures/ArithmeticCoding.java
2019-11-15 12:59:38 +01:00

209 lines
5.1 KiB
Java

import java.io.*;
import java.util.*;
// https://en.wikipedia.org/wiki/Arithmetic_coding
/**
 * Adaptive order-0 arithmetic coder for sequences of byte values (0..255).
 *
 * <p>Every byte value and an extra END marker start with frequency 1. Each coded symbol's
 * frequency is incremented afterwards, so the decoder rebuilds the same model without it being
 * transmitted. Cumulative frequencies live in a Fenwick tree.
 *
 * <p>The coder keeps a {@value #BITS}-bit integer interval {@code [low, high]}. Renormalization
 * works in two ways:
 * <ul>
 *   <li>When the top bits of both ends agree, that bit is emitted and shifted out.
 *   <li>When the interval straddles the midpoint but has become narrower than the total frequency
 *       count, it is expanded around the midpoint and the corresponding output bit is deferred.
 * </ul>
 *
 * <p>The encoded form is one {@code int} (0 or 1) per bit, not packed bytes. Instances keep the
 * state of the current call in fields, so a single instance must not be used from several
 * threads at once.
 */
public class ArithmeticCoding {
    /** Width in bits of the coding interval. */
    static final int BITS = 30;
    /** The most significant bit of a {@link #BITS}-bit value, i.e. the interval midpoint. */
    static final int HIGHEST_BIT = 1 << (BITS - 1);
    /** Mask keeping the low {@link #BITS} bits; also the initial upper bound of the interval. */
    static final int MASK = (1 << BITS) - 1;
    /** Pseudo-symbol coded after the last input value to mark the end of the stream. */
    static final int END = 256;
    /** Quarter of the interval range; the midpoint expansion shifts by this offset. */
    static final int QUARTER = 1 << (BITS - 2);
    /**
     * Longest supported input. The midpoint expansion and the renormalization invariant require
     * the total frequency count ({@code END + 1 + values coded so far}) to stay at most
     * {@link #QUARTER}.
     */
    static final int MAX_INPUT_LENGTH = QUARTER - END - 1;

    long low, high;
    /** Output bits deferred by midpoint expansions; emitted inverted after the next real bit. */
    int additionalBits;
    /** Decoder: the {@link #BITS}-bit window of the code stream currently being examined. */
    long value;
    /** Fenwick tree of symbol frequencies (indices 0..END). */
    int[] cumFreq;
    /** Decoder input. */
    int[] bits;
    /** Decoder: index of the next unread element of {@link #bits}. */
    int bitsPos;
    /** Decoder: zero bits substituted after {@link #bits} was exhausted. */
    int paddingBits;
    List<Boolean> encodedBits;
    List<Integer> decodedBytes;

    /**
     * Encodes a sequence of byte values.
     *
     * @param inputBytes values to encode, each in {@code 0..255}
     * @return the encoded stream, one element (0 or 1) per bit
     * @throws IllegalArgumentException if a value is outside {@code 0..255} or the input is longer
     *     than {@link #MAX_INPUT_LENGTH}
     * @throws NullPointerException if {@code inputBytes} is null
     */
    public int[] encode(int[] inputBytes) {
        // Validate before touching any state so a bad value cannot produce a partial stream.
        if (inputBytes.length > MAX_INPUT_LENGTH) {
            throw new IllegalArgumentException(
                    "input too long: " + inputBytes.length + " > " + MAX_INPUT_LENGTH);
        }
        for (int i = 0; i < inputBytes.length; i++) {
            int c = inputBytes[i];
            // END (256) would be indistinguishable from the end-of-stream marker.
            if (c < 0 || c >= END) {
                throw new IllegalArgumentException(
                        "value at index " + i + " is not a byte (0..255): " + c);
            }
        }
        cumFreq = createFenwickTree(END + 1);
        encodedBits = new ArrayList<>();
        low = 0;
        high = MASK;
        additionalBits = 0;
        for (int c : inputBytes) {
            encodeSymbol(c);
        }
        encodeSymbol(END);
        // Flush. The renormalization loop leaves the top bits of low and high different, so the
        // midpoint lies in [low, high]. Emitting a single 1 (plus pending bits) selects it,
        // because the decoder pads the stream with zeros.
        outputBit(true);
        int[] result = new int[encodedBits.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = encodedBits.get(i) ? 1 : 0;
        }
        return result;
    }

    /** Narrows the interval to symbol {@code c}'s slice, renormalizes, then counts {@code c}. */
    void encodeSymbol(int c) {
        int total = sum(cumFreq, END); // unchanged until increment() at the end
        long range = high - low + 1;
        high = low + range * sum(cumFreq, c) / total - 1;
        low = low + range * sum(cumFreq, c - 1) / total;
        while (true) {
            if ((low & HIGHEST_BIT) == (high & HIGHEST_BIT)) {
                // Top bit is settled: emit it and shift it out.
                outputBit((high & HIGHEST_BIT) != 0);
                low = (low << 1) & MASK;
                high = ((high << 1) + 1) & MASK;
            } else if (high - low < total) {
                // Straddles the midpoint but too narrow: expand around the midpoint, defer a bit.
                low = (low - QUARTER) << 1;
                high = ((high - QUARTER) << 1) + 1;
                ++additionalBits;
            } else {
                break;
            }
        }
        increment(cumFreq, c);
    }

    /** Appends {@code bit} followed by all deferred bits, which take the opposite value. */
    void outputBit(boolean bit) {
        encodedBits.add(bit);
        for (; additionalBits > 0; additionalBits--) {
            encodedBits.add(!bit);
        }
    }

    /**
     * Decodes a stream produced by {@link #encode}.
     *
     * @param bits encoded stream, one element (0 or 1) per bit
     * @return the decoded byte values
     * @throws IllegalArgumentException if {@code bits} contains an element other than 0 or 1, or
     *     is detected to be malformed (not every malformed stream is detected)
     * @throws NullPointerException if {@code bits} is null
     */
    public int[] decode(int[] bits) {
        this.bits = bits;
        bitsPos = 0;
        paddingBits = 0;
        cumFreq = createFenwickTree(END + 1);
        decodedBytes = new ArrayList<>();
        value = 0;
        for (int i = 0; i < BITS; i++) {
            value = (value << 1) + nextBit();
        }
        low = 0;
        high = MASK;
        while (true) {
            int c = decodeSymbol();
            if (c == END) {
                break;
            }
            // encode() never emits more values than this, and the model breaks beyond it.
            if (decodedBytes.size() == MAX_INPUT_LENGTH) {
                throw new IllegalArgumentException("malformed input: too many symbols before END");
            }
            decodedBytes.add(c);
            increment(cumFreq, c);
        }
        int[] bytes = new int[decodedBytes.size()];
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] = decodedBytes.get(i);
        }
        return bytes;
    }

    /**
     * Decodes one symbol from the current {@link #value} window, mirroring the encoder's interval
     * update and renormalization. The model is not updated here.
     */
    int decodeSymbol() {
        int total = sum(cumFreq, END);
        long range = high - low + 1;
        int cum = (int) (((value - low + 1) * total - 1) / range);
        int c = upper_bound(cumFreq, cum);
        // For a valid stream value lies in [low, high], so cum < total and c <= END.
        if (c > END) {
            throw new IllegalArgumentException("malformed input: code value outside the interval");
        }
        high = low + range * sum(cumFreq, c) / total - 1;
        low = low + range * sum(cumFreq, c - 1) / total;
        while (true) {
            if ((low & HIGHEST_BIT) == (high & HIGHEST_BIT)) {
                low = (low << 1) & MASK;
                high = ((high << 1) + 1) & MASK;
                value = ((value << 1) + nextBit()) & MASK;
            } else if (high - low < total) {
                low = (low - QUARTER) << 1;
                high = ((high - QUARTER) << 1) + 1;
                value = ((value - QUARTER) << 1) + nextBit();
            } else {
                break;
            }
        }
        return c;
    }

    /** Returns the next input bit, or 0 once the input is exhausted. */
    private int nextBit() {
        if (bitsPos < bits.length) {
            int b = bits[bitsPos++];
            if (b != 0 && b != 1) {
                throw new IllegalArgumentException(
                        "bit at index " + (bitsPos - 1) + " is not 0 or 1: " + b);
            }
            return b;
        }
        // A stream from encode() has one bit per decoder shift plus one flush bit. Decoding it
        // therefore reads exactly BITS - 1 padding zeros; needing more means END never appeared.
        if (++paddingBits > BITS) {
            throw new IllegalArgumentException("malformed input: END symbol not found");
        }
        return 0;
    }

    /** Adds 1 to the frequency at index {@code i} of Fenwick tree {@code t}. */
    static void increment(int[] t, int i) {
        for (; i < t.length; i |= i + 1) {
            ++t[i];
        }
    }

    /** Returns the sum of frequencies at indices {@code 0..i} (0 when {@code i < 0}). */
    static int sum(int[] t, int i) {
        int res = 0;
        for (; i >= 0; i = (i & (i + 1)) - 1) {
            res += t[i];
        }
        return res;
    }

    /**
     * Returns the smallest index {@code p} with {@code sum(t, p) > sum}, or {@code t.length} if
     * there is none. This is a binary descent that assumes all frequencies are positive.
     */
    static int upper_bound(int[] t, int sum) {
        int pos = -1;
        for (int blockSize = Integer.highestOneBit(t.length); blockSize != 0; blockSize >>= 1) {
            int nextPos = pos + blockSize;
            // t[nextPos] holds the frequencies of indices pos+1..nextPos.
            if (nextPos < t.length && sum >= t[nextPos]) {
                sum -= t[nextPos];
                pos = nextPos;
            }
        }
        return pos + 1;
    }

    /** Builds, in O(n), a Fenwick tree over {@code n} frequencies that all start at 1. */
    static int[] createFenwickTree(int n) {
        int[] res = new int[n];
        for (int i = 0; i < n; i++) {
            ++res[i];
            int j = i | (i + 1); // parent node that also covers index i
            if (j < n) {
                res[j] += res[i];
            }
        }
        return res;
    }

    /**
     * Self-test. It round-trips a long run of zeros, random byte arrays and a file, printing the
     * sizes in bytes.
     *
     * @param args optional path of the file to compress; defaults to
     *     {@code src/ArithmeticCoding.java}
     * @throws IOException if the file cannot be read
     */
    public static void main(String[] args) throws IOException {
        ArithmeticCoding codec = new ArithmeticCoding();
        int[] a = new int[1_000_000];
        int[] encodedBits = codec.encode(a);
        System.out.println(a.length + " -> " + encodedBits.length / 8);
        System.out.println(Arrays.equals(a, codec.decode(encodedBits)));
        Random rnd = new Random();
        for (int step = 0; step < 10_000; step++) {
            int n = rnd.nextInt(100) + 1;
            // The bound is exclusive: 256 so that byte value 255 is covered too.
            int[] inputBytes = rnd.ints(n, 0, 256).toArray();
            encodedBits = codec.encode(inputBytes);
            int[] decodedBytes = codec.decode(encodedBits);
            if (!Arrays.equals(inputBytes, decodedBytes)) {
                throw new RuntimeException("round-trip mismatch at step " + step
                        + " for input " + Arrays.toString(inputBytes));
            }
        }
        String path = args.length > 0 ? args[0] : "src/ArithmeticCoding.java";
        byte[] fileBytes = readFully(path);
        a = new int[fileBytes.length];
        for (int i = 0; i < fileBytes.length; i++) {
            a[i] = Byte.toUnsignedInt(fileBytes[i]);
        }
        encodedBits = codec.encode(a);
        System.out.printf(Locale.US, "%d -> %d (%.0f)\n",
                a.length, encodedBits.length / 8, optimalCompressedLength(a));
        System.out.println(Arrays.equals(a, codec.decode(encodedBits)));
    }

    /** Reads the whole file at {@code path}, closing it afterwards. */
    private static byte[] readFully(String path) throws IOException {
        try (InputStream in = new FileInputStream(path)) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] chunk = new byte[8192];
            int n;
            // A single read() may return fewer bytes than available, so loop until EOF.
            while ((n = in.read(chunk)) != -1) {
                out.write(chunk, 0, n);
            }
            return out.toByteArray();
        }
    }

    /**
     * Returns {@code sum(f * log2(n / f)) / 8} over the frequencies {@code f} of the values in
     * {@code a}, where {@code n = a.length}. This is the size in bytes of an ideal static order-0
     * code using the input's own frequencies; the END symbol is not counted.
     *
     * @param a non-negative values
     */
    static double optimalCompressedLength(int[] a) {
        int max = 0;
        for (int x : a) {
            max = Math.max(max, x);
        }
        int[] freq = new int[max + 1];
        for (int x : a) {
            ++freq[x];
        }
        double optimalLength = 0;
        for (int f : freq) {
            if (f > 0) {
                optimalLength += f * Math.log((double) a.length / f) / Math.log(2) / 8;
            }
        }
        return optimalLength;
    }
}