106 lines
2.6 KiB
Java
106 lines
2.6 KiB
Java
import java.util.*;
|
|
import java.util.stream.IntStream;
|
|
|
|
/**
 * Selection ("n-th element") algorithms on {@code int} arrays.
 *
 * <p>After a successful call on the range {@code [low, high)} with target index {@code n},
 * {@code a[n]} holds the value that would be there if the range were sorted, no element
 * before {@code n} in the range is greater than {@code a[n]}, and no element after it is
 * smaller. The relative order of the remaining elements is unspecified.
 */
public class NthElement {

    // See: http://www.cplusplus.com/reference/algorithm/nth_element
    // O(n) on average
    /**
     * Quickselect with a randomly chosen pivot.
     *
     * @param a    array to rearrange in place
     * @param low  first index of the range (inclusive)
     * @param high end of the range (exclusive)
     * @param n    index, with {@code low <= n < high}, that must receive its sorted value
     * @param rnd  source of randomness used to pick pivots
     * @throws IllegalArgumentException if {@code n} is outside {@code [low, high)}, which
     *                                  includes the case of an empty range
     */
    public static void nth_element(int[] a, int low, int high, int n, Random rnd) {
        // Fail fast, before the array is touched. Without this check a bad n only surfaced
        // later as "bound must be positive" from Random.nextInt on a collapsed range.
        checkNth(low, high, n);
        while (true) {
            int k = partition(a, low, high, low + rnd.nextInt(high - low));
            if (n < k) {
                high = k;
            } else if (n > k) {
                low = k + 1;
            } else {
                return;
            }
        }
    }

    /**
     * Partitions {@code a[fromInclusive, toExclusive)} around the value stored at
     * {@code separatorIndex}.
     *
     * <p>On return the separator value sits at the returned index {@code k}; every element of
     * the range before {@code k} is {@code <=} it and every element after {@code k} is
     * {@code >=} it. Elements equal to the separator may end up on either side.
     *
     * @return the final index of the separator; for a range with fewer than two elements the
     *         array is left untouched and {@code toExclusive - 1} is returned
     */
    static int partition(int[] a, int fromInclusive, int toExclusive, int separatorIndex) {
        int i = fromInclusive;
        int j = toExclusive - 1;
        if (i >= j) {
            return j;
        }
        int separator = a[separatorIndex];
        // Park the separator at the front of the range; the scan works on the rest.
        swap(a, i++, separatorIndex);
        while (i <= j) {
            while (i <= j && a[i] < separator) {
                ++i;
            }
            while (i <= j && a[j] > separator) {
                --j;
            }
            if (i >= j) {
                break;
            }
            // Both scans stop on elements equal to the separator, so runs of duplicates are
            // split between the two sides instead of piling up on one of them.
            swap(a, i++, j--);
        }
        // a[j] is <= separator here, so it may take the front slot while the separator
        // moves to its final position j.
        swap(a, j, fromInclusive);
        return j;
    }

    /** Exchanges {@code a[i]} and {@code a[j]}. */
    static void swap(int[] a, int i, int j) {
        int t = a[j];
        a[j] = a[i];
        a[i] = t;
    }

    // O(n) worst case. See Cormen et al
    /**
     * Deterministic selection using the median of the medians of groups of five as pivot.
     *
     * <p>A range with at most one element is left untouched and the call returns immediately.
     *
     * @param a    array to rearrange in place
     * @param low  first index of the range (inclusive)
     * @param high end of the range (exclusive)
     * @param n    index, with {@code low <= n < high}, that must receive its sorted value
     * @throws IllegalArgumentException if the range has more than one element and {@code n}
     *                                  is outside {@code [low, high)}
     */
    public static void nth_element2(int[] a, int low, int high, int n) {
        if (high - low <= 1) {
            return;
        }
        // Fail fast, before the array is touched. Without this check a bad n only surfaced
        // later as an ArrayIndexOutOfBoundsException on an empty array of group medians.
        checkNth(low, high, n);
        while (true) {
            int median = medianOfMedians(a, low, high);
            int separatorIndex = indexOf(a, low, high, median);
            int k = partition(a, low, high, separatorIndex);
            if (n < k) {
                high = k;
            } else if (n > k) {
                low = k + 1;
            } else {
                return;
            }
        }
    }

    /** Rejects a target index that lies outside the half-open range {@code [low, high)}. */
    private static void checkNth(int low, int high, int n) {
        if (n < low || n >= high) {
            throw new IllegalArgumentException(
                    "n = " + n + " is outside the range [" + low + ", " + high + ")");
        }
    }

    /**
     * Returns the median of the medians of consecutive groups of up to five elements of the
     * non-empty range {@code a[low, high)}. Elements are reordered inside their own group.
     */
    private static int medianOfMedians(int[] a, int low, int high) {
        int[] medians = new int[(high - low + 4) / 5];
        int count = 0;
        for (int start = low; start < high; start += 5) {
            int end = Math.min(start + 5, high);
            floatGroupMedian(a, start, end);
            medians[count++] = a[(start + end) >>> 1];
        }
        nth_element2(medians, 0, medians.length, medians.length / 2);
        return medians[medians.length / 2];
    }

    /**
     * Runs three bubble passes over the group {@code a[start, end)} of at most five elements.
     * Each pass moves the largest unsettled element to the end, so afterwards the last three
     * slots of the group, including the middle slot {@code (start + end) >>> 1}, hold their
     * sorted values.
     */
    private static void floatGroupMedian(int[] a, int start, int end) {
        for (int pass = 0; pass < 3; pass++) {
            for (int k = start; k + 1 < end; k++) {
                if (a[k] > a[k + 1]) {
                    swap(a, k, k + 1);
                }
            }
        }
    }

    /** Returns the first index in {@code [low, high)} whose element equals {@code value}. */
    private static int indexOf(int[] a, int low, int high, int value) {
        for (int i = low; i < high; i++) {
            if (a[i] == value) {
                return i;
            }
        }
        throw new IllegalStateException("value " + value + " not found in range");
    }

    // Random test
    public static void main(String[] args) {
        Random rnd = new Random(1);
        int len = 1000_000;
        // Large all-equal inputs first.
        nth_element(new int[len], 0, len, 0, rnd);
        nth_element2(new int[len], 0, len, 0);
        for (int step = 0; step < 100_000; step++) {
            int n = rnd.nextInt(10) + 1;
            int[] a = rnd.ints(n, 0, 10).toArray();
            int[] b = a.clone();
            int k = rnd.nextInt(n);
            nth_element(a, 0, n, k, rnd);
            nth_element2(b, 0, n, k);
            int[] sa = a.clone();
            Arrays.sort(sa);
            int[] sb = b.clone();
            Arrays.sort(sb);
            if (!Arrays.equals(sa, sb)) {
                throw new RuntimeException("the two algorithms produced different multisets");
            }
            if (a[k] != sa[k] || b[k] != sb[k]) {
                throw new RuntimeException("element " + k + " is not at its sorted position");
            }
            for (int i = 0; i < n; i++) {
                if (i < k && a[i] > a[k] || i > k && a[i] < a[k]) {
                    throw new RuntimeException("nth_element left index " + i + " on the wrong side");
                }
                if (i < k && b[i] > b[k] || i > k && b[i] < b[k]) {
                    throw new RuntimeException("nth_element2 left index " + i + " on the wrong side");
                }
            }
        }
    }
}
|