import java.util.Random; public class KdTreeRectQuery { public static class Point { int x, y; public Point(int x, int y) { this.x = x; this.y = y; } } int[] tx; int[] ty; int[] minx, miny, maxx, maxy; int[] count; public KdTreeRectQuery(Point[] points) { int n = points.length; tx = new int[n]; ty = new int[n]; minx = new int[n]; miny = new int[n]; maxx = new int[n]; maxy = new int[n]; count = new int[n]; build(0, n, true, points); } void build(int low, int high, boolean divX, Point[] points) { if (low >= high) return; int mid = (low + high) >>> 1; nth_element(points, low, high, mid, divX); tx[mid] = points[mid].x; ty[mid] = points[mid].y; count[mid] = high - low; minx[mid] = Integer.MAX_VALUE; miny[mid] = Integer.MAX_VALUE; maxx[mid] = Integer.MIN_VALUE; maxy[mid] = Integer.MIN_VALUE; for (int i = low; i < high; i++) { minx[mid] = Math.min(minx[mid], points[i].x); miny[mid] = Math.min(miny[mid], points[i].y); maxx[mid] = Math.max(maxx[mid], points[i].x); maxy[mid] = Math.max(maxy[mid], points[i].y); } build(low, mid, !divX, points); build(mid + 1, high, !divX, points); } static final Random rnd = new Random(1); // See: http://www.cplusplus.com/reference/algorithm/nth_element static void nth_element(Point[] a, int low, int high, int n, boolean divX) { while (true) { int k = partition(a, low, high, low + rnd.nextInt(high - low), divX); if (n < k) high = k; else if (n > k) low = k + 1; else return; } } static int partition(Point[] a, int fromInclusive, int toExclusive, int separatorIndex, boolean divX) { int i = fromInclusive; int j = toExclusive - 1; if (i >= j) return j; double separator = divX ? a[separatorIndex].x : a[separatorIndex].y; swap(a, i++, separatorIndex); while (i <= j) { while (i <= j && (divX ? a[i].x : a[i].y) < separator) ++i; while (i <= j && (divX ? a[j].x : a[j].y) > separator) --j; if (i >= j) break; swap(a, i++, j--); } swap(a, j, fromInclusive); return j; } static void swap(Point[] a, int i, int j) { Point t = a[i]; a[i] = a[j]; a[j] = t; } // number of points in [x1,x2] x [y1,y2] public int count(int x1, int y1, int x2, int y2) { return count(0, tx.length, x1, y1, x2, y2); } int count(int low, int high, int x1, int y1, int x2, int y2) { if (low >= high) return 0; int mid = (low + high) >>> 1; int ax = minx[mid]; int ay = miny[mid]; int bx = maxx[mid]; int by = maxy[mid]; if (ax > x2 || x1 > bx || ay > y2 || y1 > by) return 0; if (x1 <= ax && bx <= x2 && y1 <= ay && by <= y2) return count[mid]; int res = 0; res += count(low, mid, x1, y1, x2, y2); res += count(mid + 1, high, x1, y1, x2, y2); if (x1 <= tx[mid] && tx[mid] <= x2 && y1 <= ty[mid] && ty[mid] <= y2) ++res; return res; } // Usage example public static void main(String[] args) { int[] x = {0, 10, 0, 10}; int[] y = {0, 10, 10, 0}; Point[] points = new Point[x.length]; for (int i = 0; i < points.length; i++) points[i] = new Point(x[i], y[i]); KdTreeRectQuery kdTree = new KdTreeRectQuery(points); int count = kdTree.count(0, 0, 10, 10); System.out.println(4 == count); } }