138 lines
3.2 KiB
Java
138 lines
3.2 KiB
Java
|
import java.util.Random;
|
||
|
|
||
|
public class KdTreeRectQuery {
|
||
|
|
||
|
public static class Point {
|
||
|
int x, y;
|
||
|
|
||
|
public Point(int x, int y) {
|
||
|
this.x = x;
|
||
|
this.y = y;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
int[] tx;
|
||
|
int[] ty;
|
||
|
int[] minx, miny, maxx, maxy;
|
||
|
int[] count;
|
||
|
|
||
|
public KdTreeRectQuery(Point[] points) {
|
||
|
int n = points.length;
|
||
|
tx = new int[n];
|
||
|
ty = new int[n];
|
||
|
minx = new int[n];
|
||
|
miny = new int[n];
|
||
|
maxx = new int[n];
|
||
|
maxy = new int[n];
|
||
|
count = new int[n];
|
||
|
build(0, n, true, points);
|
||
|
}
|
||
|
|
||
|
void build(int low, int high, boolean divX, Point[] points) {
|
||
|
if (low >= high)
|
||
|
return;
|
||
|
int mid = (low + high) >>> 1;
|
||
|
nth_element(points, low, high, mid, divX);
|
||
|
|
||
|
tx[mid] = points[mid].x;
|
||
|
ty[mid] = points[mid].y;
|
||
|
count[mid] = high - low;
|
||
|
|
||
|
minx[mid] = Integer.MAX_VALUE;
|
||
|
miny[mid] = Integer.MAX_VALUE;
|
||
|
maxx[mid] = Integer.MIN_VALUE;
|
||
|
maxy[mid] = Integer.MIN_VALUE;
|
||
|
for (int i = low; i < high; i++) {
|
||
|
minx[mid] = Math.min(minx[mid], points[i].x);
|
||
|
miny[mid] = Math.min(miny[mid], points[i].y);
|
||
|
maxx[mid] = Math.max(maxx[mid], points[i].x);
|
||
|
maxy[mid] = Math.max(maxy[mid], points[i].y);
|
||
|
}
|
||
|
|
||
|
build(low, mid, !divX, points);
|
||
|
build(mid + 1, high, !divX, points);
|
||
|
}
|
||
|
|
||
|
static final Random rnd = new Random(1);
|
||
|
|
||
|
// See: http://www.cplusplus.com/reference/algorithm/nth_element
|
||
|
static void nth_element(Point[] a, int low, int high, int n, boolean divX) {
|
||
|
while (true) {
|
||
|
int k = partition(a, low, high, low + rnd.nextInt(high - low), divX);
|
||
|
if (n < k)
|
||
|
high = k;
|
||
|
else if (n > k)
|
||
|
low = k + 1;
|
||
|
else
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
static int partition(Point[] a, int fromInclusive, int toExclusive, int separatorIndex, boolean divX) {
|
||
|
int i = fromInclusive;
|
||
|
int j = toExclusive - 1;
|
||
|
if (i >= j) return j;
|
||
|
double separator = divX ? a[separatorIndex].x : a[separatorIndex].y;
|
||
|
swap(a, i++, separatorIndex);
|
||
|
while (i <= j) {
|
||
|
while (i <= j && (divX ? a[i].x : a[i].y) < separator)
|
||
|
++i;
|
||
|
while (i <= j && (divX ? a[j].x : a[j].y) > separator)
|
||
|
--j;
|
||
|
if (i >= j)
|
||
|
break;
|
||
|
swap(a, i++, j--);
|
||
|
}
|
||
|
swap(a, j, fromInclusive);
|
||
|
return j;
|
||
|
}
|
||
|
|
||
|
static void swap(Point[] a, int i, int j) {
|
||
|
Point t = a[i];
|
||
|
a[i] = a[j];
|
||
|
a[j] = t;
|
||
|
}
|
||
|
|
||
|
// number of points in [x1,x2] x [y1,y2]
|
||
|
public int count(int x1, int y1, int x2, int y2) {
|
||
|
return count(0, tx.length, x1, y1, x2, y2);
|
||
|
}
|
||
|
|
||
|
int count(int low, int high, int x1, int y1, int x2, int y2) {
|
||
|
if (low >= high)
|
||
|
return 0;
|
||
|
int mid = (low + high) >>> 1;
|
||
|
|
||
|
int ax = minx[mid];
|
||
|
int ay = miny[mid];
|
||
|
int bx = maxx[mid];
|
||
|
int by = maxy[mid];
|
||
|
|
||
|
if (ax > x2 || x1 > bx || ay > y2 || y1 > by)
|
||
|
return 0;
|
||
|
if (x1 <= ax && bx <= x2 && y1 <= ay && by <= y2)
|
||
|
return count[mid];
|
||
|
|
||
|
int res = 0;
|
||
|
res += count(low, mid, x1, y1, x2, y2);
|
||
|
res += count(mid + 1, high, x1, y1, x2, y2);
|
||
|
if (x1 <= tx[mid] && tx[mid] <= x2 && y1 <= ty[mid] && ty[mid] <= y2)
|
||
|
++res;
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
// Usage example
|
||
|
public static void main(String[] args) {
|
||
|
int[] x = {0, 10, 0, 10};
|
||
|
int[] y = {0, 10, 10, 0};
|
||
|
|
||
|
Point[] points = new Point[x.length];
|
||
|
for (int i = 0; i < points.length; i++)
|
||
|
points[i] = new Point(x[i], y[i]);
|
||
|
|
||
|
KdTreeRectQuery kdTree = new KdTreeRectQuery(points);
|
||
|
int count = kdTree.count(0, 0, 10, 10);
|
||
|
System.out.println(4 == count);
|
||
|
}
|
||
|
}
|