137 lines
3.1 KiB
Java
137 lines
3.1 KiB
Java
import java.util.Random;
|
|
|
|
// https://en.wikipedia.org/wiki/Metric_tree
|
|
/**
 * Metric tree over 2D integer points, stored implicitly inside the two coordinate arrays
 * (no node objects are allocated).
 *
 * <p>Layout of every sub-range {@code [low, high)} holding at least three points:
 * <ul>
 *   <li>index {@code low} holds a randomly chosen pivot point;</li>
 *   <li>index {@code mid = (low + 1 + high) >>> 1} holds the point whose distance to the pivot
 *       is the median of the remaining points;</li>
 *   <li>{@code [low + 1, mid)} holds points no farther from the pivot than the median point,
 *       {@code (mid, high)} holds points no nearer than it; both halves are built recursively.</li>
 * </ul>
 * Sub-ranges of one or two points are left unordered and are scanned directly by the search.
 *
 * <p><b>The constructor reorders the arrays passed to it in place</b> and keeps references to
 * them; indices returned by {@link #findNearestNeighbour(int, int)} refer to the reordered arrays.
 *
 * <p>Queries store their state in the {@code bestDist}/{@code bestNode} fields, so one instance
 * must not be queried from several threads at once.
 */
public class MetricTree {

    /** X coordinates; permuted in place by the constructor. */
    int[] x;

    /** Y coordinates; permuted in lock-step with {@link #x}. */
    int[] y;

    /**
     * Builds the tree over the given points.
     *
     * @param x x coordinates; reordered in place and retained (not copied)
     * @param y y coordinates; reordered in place and retained (not copied)
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public MetricTree(int[] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "x and y must have the same length: " + x.length + " != " + y.length);
        }
        this.x = x;
        this.y = y;
        build(0, x.length);
    }

    /** Shared pseudo-random source with a fixed seed; used for pivot selection and by the self-test. */
    static final Random rnd = new Random(1);

    /**
     * Recursively arranges the points of {@code [low, high)} into the layout described on the class.
     *
     * @param low  first index of the range (inclusive)
     * @param high end of the range (exclusive)
     */
    void build(int low, int high) {
        // One or two points need no structure: the search inspects both directly.
        if (high - low <= 2) {
            return;
        }
        // Move a random point to the front; it becomes the pivot of this range.
        swap(low + rnd.nextInt(high - low), low);
        int mid = (low + 1 + high) >>> 1;
        // Order the remaining points around the median distance to the pivot.
        nth_element(low + 1, high, mid);

        build(low + 1, mid);
        build(mid + 1, high);
    }

    /**
     * Quickselect on distance to the point at index {@code low - 1}: on return, index {@code n}
     * holds the point that would be there if {@code [low, high)} were sorted by that distance,
     * points before it are no farther and points after it are no nearer.
     *
     * See http://www.cplusplus.com/reference/algorithm/nth_element
     *
     * @param low  first index of the range (inclusive); index {@code low - 1} must hold the pivot
     * @param high end of the range (exclusive)
     * @param n    target index, {@code low <= n < high}
     */
    void nth_element(int low, int high, int n) {
        // The reference point is fixed for the whole selection, even as low moves right.
        int center = low - 1;
        while (true) {
            int k = partition(center, low, high, low + rnd.nextInt(high - low));
            if (n < k) {
                high = k;
            } else if (n > k) {
                low = k + 1;
            } else {
                return;
            }
        }
    }

    /**
     * Partitions {@code [fromInclusive, toExclusive)} by squared distance to the point at
     * {@code center}, using the point at {@code separatorIndex} as separator.
     *
     * @param center         index of the reference point distances are measured from
     * @param fromInclusive  first index of the range (inclusive)
     * @param toExclusive    end of the range (exclusive)
     * @param separatorIndex index inside the range of the separator element
     * @return final index of the separator; elements before it are no farther from the reference
     *         point, elements after it are no nearer
     */
    int partition(int center, int fromInclusive, int toExclusive, int separatorIndex) {
        int i = fromInclusive;
        int j = toExclusive - 1;
        // A range of at most one element is already partitioned.
        if (i >= j) {
            return j;
        }
        long separator = dist2(x[center], y[center], x[separatorIndex], y[separatorIndex]);
        // Park the separator at the front while the rest of the range is partitioned.
        swap(i++, separatorIndex);
        while (i <= j) {
            while (i <= j && (dist2(x[center], y[center], x[i], y[i]) < separator)) {
                ++i;
            }
            while (i <= j && (dist2(x[center], y[center], x[j], y[j]) > separator)) {
                --j;
            }
            if (i >= j) {
                break;
            }
            swap(i++, j--);
        }
        // j now marks the last slot of the "not farther" side; drop the separator there.
        swap(j, fromInclusive);
        return j;
    }

    /**
     * Squared Euclidean distance between {@code (x1, y1)} and {@code (x2, y2)}.
     *
     * <p>The coordinate differences are formed in 64-bit arithmetic, so they cannot wrap around
     * the way a 32-bit subtraction would. The result is exact whenever both differences are
     * smaller than 2^31 in magnitude; beyond that the sum of squares may exceed {@code long}.
     *
     * @return {@code (x1 - x2)^2 + (y1 - y2)^2}
     */
    static long dist2(int x1, int y1, int x2, int y2) {
        long dx = (long) x1 - x2;
        long dy = (long) y1 - y2;
        return dx * dx + dy * dy;
    }

    /** Exchanges the points stored at indices {@code i} and {@code j} (both coordinates). */
    void swap(int i, int j) {
        int t = x[i];
        x[i] = x[j];
        x[j] = t;
        t = y[i];
        y[i] = y[j];
        y[j] = t;
    }

    /** Smallest squared distance found by the most recent query. */
    long bestDist;

    /** Index of the point that achieved {@link #bestDist}, or -1 if no point has been seen. */
    int bestNode;

    /**
     * Finds the stored point closest to {@code (px, py)}.
     *
     * @return index of a nearest point in the (reordered) coordinate arrays, or -1 if the tree
     *         holds no points; the squared distance is left in {@code bestDist}
     */
    public int findNearestNeighbour(int px, int py) {
        bestDist = Long.MAX_VALUE;
        // Reset so an empty tree (or a previous query) cannot leak a stale index.
        bestNode = -1;
        findNearestNeighbour(0, x.length, px, py);
        return bestNode;
    }

    /**
     * Searches the sub-range {@code [low, high)} for a point closer to {@code (px, py)} than the
     * current {@code bestDist}, updating {@code bestDist}/{@code bestNode} on improvement.
     */
    void findNearestNeighbour(int low, int high, int px, int py) {
        if (high - low <= 0) {
            return;
        }
        long d2 = dist2(px, py, x[low], y[low]);
        if (bestDist > d2) {
            bestDist = d2;
            bestNode = low;
        }

        if (high - low <= 1) {
            return;
        }

        int mid = (low + 1 + high) >>> 1;
        long midDist2 = dist2(px, py, x[mid], y[mid]);
        if (bestDist > midDist2) {
            bestDist = midDist2;
            bestNode = mid;
        }

        // R: distance from the pivot to the median point, i.e. the boundary between the halves.
        double R = Math.sqrt(dist2(x[low], y[low], x[mid], y[mid]));
        // d: distance from the query to the pivot.
        double d = Math.sqrt(d2);

        // Points in [low + 1, mid) lie within R of the pivot, so by the triangle inequality they
        // are at least d - R away from the query; skip them when that already exceeds the best.
        if (d <= R + Math.sqrt(bestDist)) {
            findNearestNeighbour(low + 1, mid, px, py);
        }
        // Points in (mid, high) lie at least R from the pivot, hence at least R - d from the
        // query. The best distance is re-read here so that any improvement found in the first
        // half tightens this test.
        if (d >= R - Math.sqrt(bestDist)) {
            findNearestNeighbour(mid + 1, high, px, py);
        }
    }

    // random test
    /** Self-test: compares tree queries against a brute-force scan on random inputs. */
    public static void main(String[] args) {
        for (int step = 0; step < 100_000; step++) {
            int qx = rnd.nextInt(100) - 50;
            int qy = rnd.nextInt(100) - 50;
            int n = rnd.nextInt(100) + 1;
            int[] x = new int[n];
            int[] y = new int[n];
            long minDist = Long.MAX_VALUE;
            for (int i = 0; i < n; i++) {
                x[i] = rnd.nextInt(100) - 50;
                y[i] = rnd.nextInt(100) - 50;
                minDist = Math.min(minDist, dist2(qx, qy, x[i], y[i]));
            }
            MetricTree metricTree = new MetricTree(x, y);
            int index = metricTree.findNearestNeighbour(qx, qy);
            // x and y were reordered by the constructor, so index is valid for them as they are now.
            if (minDist != metricTree.bestDist || dist2(qx, qy, x[index], y[index]) != minDist) {
                throw new RuntimeException("nearest-neighbour mismatch at step " + step);
            }
        }
    }
}
|