programming-examples/java/Data_Structures/KdTreeRectQuery.java
2019-11-15 12:59:38 +01:00

138 lines
3.2 KiB
Java

import java.util.Random;
/**
 * Static 2D k-d tree that counts the points lying inside an axis-aligned rectangle.
 *
 * <p>The tree is stored implicitly in parallel arrays: the node covering the index range
 * {@code [low, high)} lives at index {@code mid = (low + high) >>> 1}; its left subtree covers
 * {@code [low, mid)} and its right subtree covers {@code [mid + 1, high)}. The splitting
 * coordinate alternates between x and y from one level to the next.
 */
public class KdTreeRectQuery {

    /** A point with integer coordinates. */
    public static class Point {
        int x, y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }
    }

    // Coordinates of the point stored at each node.
    int[] tx;
    int[] ty;
    // Bounding box of all points in the subtree rooted at each node.
    int[] minx, miny, maxx, maxy;
    // Number of points in the subtree rooted at each node (the node's own point included).
    int[] count;

    /**
     * Builds the tree over the given points.
     *
     * <p>Construction reorders the array it works on, so it runs on a shallow copy; the
     * caller's array is left in its original order.
     *
     * @param points the points to index; may be empty
     */
    public KdTreeRectQuery(Point[] points) {
        // Defensive copy: build() permutes its array via nth_element.
        Point[] work = points.clone();
        int n = work.length;
        tx = new int[n];
        ty = new int[n];
        minx = new int[n];
        miny = new int[n];
        maxx = new int[n];
        maxy = new int[n];
        count = new int[n];
        build(0, n, true, work);
    }

    /**
     * Recursively builds the subtree over {@code points[low, high)}.
     *
     * <p>Note: this method reorders {@code points} within {@code [low, high)}.
     *
     * @param low    first index of the range (inclusive)
     * @param high   end of the range (exclusive)
     * @param divX   {@code true} to split this level by x, {@code false} to split by y
     * @param points working array of points
     */
    void build(int low, int high, boolean divX, Point[] points) {
        if (low >= high) {
            return;
        }
        int mid = (low + high) >>> 1;
        // Place the median (by the current coordinate) at mid.
        nth_element(points, low, high, mid, divX);

        tx[mid] = points[mid].x;
        ty[mid] = points[mid].y;
        count[mid] = high - low;

        // Bounding box of every point in this subtree; used by count() for pruning.
        int loX = Integer.MAX_VALUE;
        int loY = Integer.MAX_VALUE;
        int hiX = Integer.MIN_VALUE;
        int hiY = Integer.MIN_VALUE;
        for (int i = low; i < high; i++) {
            Point p = points[i];
            loX = Math.min(loX, p.x);
            loY = Math.min(loY, p.y);
            hiX = Math.max(hiX, p.x);
            hiY = Math.max(hiY, p.y);
        }
        minx[mid] = loX;
        miny[mid] = loY;
        maxx[mid] = hiX;
        maxy[mid] = hiY;

        build(low, mid, !divX, points);
        build(mid + 1, high, !divX, points);
    }

    // Fixed seed, so pivot selection is reproducible from run to run.
    static final Random rnd = new Random(1);

    /**
     * Rearranges {@code a[low, high)} so that {@code a[n]} holds the element that would be at
     * index {@code n} if the range were sorted by the chosen coordinate; no element before it
     * is greater and no element after it is smaller.
     *
     * <p>See: http://www.cplusplus.com/reference/algorithm/nth_element
     *
     * @param a    array to rearrange in place
     * @param low  first index of the range (inclusive)
     * @param high end of the range (exclusive); must be greater than {@code low}
     * @param n    target index, {@code low <= n < high}
     * @param divX {@code true} to order by x, {@code false} to order by y
     */
    static void nth_element(Point[] a, int low, int high, int n, boolean divX) {
        while (true) {
            int k = partition(a, low, high, low + rnd.nextInt(high - low), divX);
            if (n < k) {
                high = k;
            } else if (n > k) {
                low = k + 1;
            } else {
                return;
            }
        }
    }

    /**
     * Partitions {@code a[fromInclusive, toExclusive)} around the element initially at
     * {@code separatorIndex}.
     *
     * <p>On return the pivot sits at the returned index {@code k}; elements before {@code k}
     * are not greater than the pivot and elements after {@code k} are not smaller. A range
     * with fewer than two elements is left untouched and {@code toExclusive - 1} is returned.
     *
     * @return the final index of the pivot element
     */
    static int partition(Point[] a, int fromInclusive, int toExclusive, int separatorIndex, boolean divX) {
        int i = fromInclusive;
        int j = toExclusive - 1;
        if (i >= j) {
            return j;
        }
        // Coordinates are ints, so the pivot key is kept as an int as well.
        int separator = divX ? a[separatorIndex].x : a[separatorIndex].y;
        // Park the pivot at the front while the rest of the range is partitioned.
        swap(a, i++, separatorIndex);
        while (i <= j) {
            while (i <= j && (divX ? a[i].x : a[i].y) < separator) {
                ++i;
            }
            while (i <= j && (divX ? a[j].x : a[j].y) > separator) {
                --j;
            }
            if (i >= j) {
                break;
            }
            swap(a, i++, j--);
        }
        // j now indexes an element not greater than the pivot (or the pivot itself).
        swap(a, j, fromInclusive);
        return j;
    }

    /** Exchanges {@code a[i]} and {@code a[j]}. */
    static void swap(Point[] a, int i, int j) {
        Point t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    /**
     * Returns the number of points in {@code [x1, x2] x [y1, y2]} (bounds inclusive).
     */
    public int count(int x1, int y1, int x2, int y2) {
        return count(0, tx.length, x1, y1, x2, y2);
    }

    /**
     * Counts the points of the subtree over {@code [low, high)} that fall inside the query
     * rectangle {@code [x1, x2] x [y1, y2]}.
     */
    int count(int low, int high, int x1, int y1, int x2, int y2) {
        if (low >= high) {
            return 0;
        }
        int mid = (low + high) >>> 1;
        int ax = minx[mid];
        int ay = miny[mid];
        int bx = maxx[mid];
        int by = maxy[mid];
        // Subtree bounding box and query rectangle are disjoint.
        if (ax > x2 || x1 > bx || ay > y2 || y1 > by) {
            return 0;
        }
        // Subtree bounding box is entirely inside the query rectangle.
        if (x1 <= ax && bx <= x2 && y1 <= ay && by <= y2) {
            return count[mid];
        }
        int res = 0;
        res += count(low, mid, x1, y1, x2, y2);
        res += count(mid + 1, high, x1, y1, x2, y2);
        // The node's own point belongs to neither child range.
        if (x1 <= tx[mid] && tx[mid] <= x2 && y1 <= ty[mid] && ty[mid] <= y2) {
            ++res;
        }
        return res;
    }

    // Usage example
    public static void main(String[] args) {
        int[] x = {0, 10, 0, 10};
        int[] y = {0, 10, 10, 0};
        Point[] points = new Point[x.length];
        for (int i = 0; i < points.length; i++) {
            points[i] = new Point(x[i], y[i]);
        }
        KdTreeRectQuery kdTree = new KdTreeRectQuery(points);
        int count = kdTree.count(0, 0, 10, 10);
        System.out.println(4 == count);
    }
}